import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from ._layers import LoadMixin, RRDB
class RRDBNet(nn.Module, LoadMixin):
    """Face quality enhancer.

    This model is capable of detecting which images have low-quality
    faces, i.e., which images have small face areas compared to the
    dimensions of the image and is able to enhance the quality of those
    images. The images are up-scaled 4 times and then resized to their
    original size - this results in less blurry faces.

    This class also inherits ``load`` method from ``LoadMixin`` class.
    The method takes a device on which to load the model and loads the
    model with a default state dictionary loaded from
    ``WEIGHTS_FILENAME`` file. It sets this model to eval mode and
    disables gradients.

    For more information on how the RRDB (BSRGAN) model works, see this
    repo: `BSRGAN <https://github.com/cszn/BSRGAN>`_. Most of the code
    was taken from that repository.

    Note:
        Whenever an input shape is mentioned, N corresponds to batch
        size, C corresponds to the number of channels, H - to input
        height, and W - to input width.
    """

    #: WEIGHTS_FILENAME (str): The constant specifying the name of
    #: ``.pth`` file from which the weights for this model should be
    #: loaded. Defaults to "bsrgan_x4_enhancer.pth".
    WEIGHTS_FILENAME = "bsrgan_x4_enhancer.pth"

    def __init__(self, min_face_factor: float = 0.001):
        """Initializes RRDB (BSRGAN) model.

        Just assigns the minimum face threshold attribute and constructs
        module layers for quality inference.

        Args:
            min_face_factor: The minimum average face factor, i.e., face
                area relative to the image, below which the whole image
                is enhanced. Defaults to 0.001.
        """
        super().__init__()

        # Init minimum face factor attribute
        self.min_face_factor = min_face_factor

        # NOTE: layer attribute names below are the state_dict keys the
        # pretrained weights are stored under - do not rename them.

        # Initialize first layers that produce features
        self.conv_first = nn.Conv2d(3, 64, 3, 1, 1)
        self.RRDB_trunk = nn.Sequential(*[RRDB(64, 32) for _ in range(23)])
        self.trunk_conv = nn.Conv2d(64, 64, 3, 1, 1)
        self.upconv1 = nn.Conv2d(64, 64, 3, 1, 1)
        self.upconv2 = nn.Conv2d(64, 64, 3, 1, 1)

        # Final layers that produce enhanced image
        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Performs forward pass.

        Takes an input tensor which is a batch of images and produces
        the same batch, except images are up-scaled 4 times.

        Args:
            x: The input tensor of shape (N, 3, H, W).

        Returns:
            An output tensor of shape (N, 3, 4*H, 4*W).
        """
        # Shallow features, then add the residual from the RRDB trunk
        fea = self.conv_first(x)
        fea = fea + self.trunk_conv(self.RRDB_trunk(fea))

        # Two 2x up-samplings (F.interpolate default mode is "nearest")
        # give the overall 4x scale
        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2)))
        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2)))

        # Map the up-scaled features back to a 3-channel image
        return self.conv_last(self.lrelu(self.HRconv(fea)))

    @torch.no_grad()
    def predict(
        self,
        images: torch.Tensor | list[torch.Tensor],
        landmarks: np.ndarray | None,
        indices: list[int] | None,
    ) -> torch.Tensor:
        """Enhances the quality of images with low-quality faces.

        Takes a batch of images and sets of landmarks for each image and
        enhances the quality of those images for which the average face
        area factor is lower than or equal to ``self.min_face_factor``.
        The face factor of a single face is the (absolute) area of the
        box spanned by its landmark at index 0 and its landmark at
        index 4, divided by the area (H * W) of the image that face
        belongs to.

        Note:
            The images are enhanced one by one instead of as a batch
            because the inference is very memory consuming and can
            result in memory errors.

        Args:
            images: Image batch of shape (N, 3, H, W) in RGB form with
                float values from 0.0 to 255.0. It must be on the same
                device as this model. It can also be a list of tensors
                of different shapes.
            landmarks: Landmarks batch of shape (``num_faces``, 5, 2)
                used to compute average face area for each image. If
                None, then every image will be enhanced.
            indices: Indices list mapping each set of landmarks to a
                specific image in ``images`` batch (multiple sets of
                landmarks can come from the same image). If None, then
                every image will be enhanced.

        Returns:
            The same image batch as ``images`` (it is modified in place
            and returned) - the shape is (N, 3, H, W), channels are in
            RGB and values range from 0.0 to 255.0. The only difference
            is that some of the images are of much higher quality, i.e.,
            less blurry. Images that have no landmarks assigned to them
            are left untouched.
        """
        # Without landmarks or indices faces cannot be measured, so every
        # image gets enhanced
        enhance_all = landmarks is None or indices is None

        # Convert once so each image's faces can be selected with a
        # vectorised boolean mask instead of a per-image Python list
        owners = None if enhance_all else np.asarray(indices)

        for i in range(len(images)):
            image = images[i]

            if not enhance_all:
                # Select all face landmarks belonging to the i'th image
                landmarks_i = landmarks[owners == i]

                if len(landmarks_i) == 0:
                    # No faces detected in this image - leave it as is
                    continue

                # Offsets between landmark 4 and landmark 0 give the box
                # width/height; abs() keeps rotated faces from yielding a
                # negative "area" that would skew the mean downwards
                w, h = (landmarks_i[:, 4] - landmarks_i[:, 0]).T

                # Use this image's own size - a list input may contain
                # images of different shapes
                image_area = image.shape[-2] * image.shape[-1]
                face_factor = np.abs(w * h) / image_area

                is_low_quality = face_factor.mean() <= self.min_face_factor

                if not is_low_quality:
                    continue

            # Enhance in [0, 1] range at 4x, then shrink back to H x W
            image_x4 = self(image.unsqueeze(0).div(255))
            image_x1 = F.interpolate(
                image_x4, scale_factor=0.25, mode="bicubic"
            )
            images[i] = image_x1.clamp(0, 1).mul(255).round()[0]

        return images