Source code for face_crop_plus.models.bise

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from ._layers import LoadMixin, ContextPath, FeatureFusionModule, BiSeNetOutput


class BiSeNet(nn.Module, LoadMixin):
    """Face attribute parser.

    This class is capable of predicting scores for 19 attributes for
    face images. After it identifies the closest attribute for each
    pixel, it can also put the whole face image to a corresponding
    attribute or mask group.

    The 19 attributes are as follows (attributes are indicated from a
    person's face perspective, meaning, for instance, left eye is the
    eye on the right hand-side of the picture, however, sides are not
    always accurate):

        * 0 - neutral
        * 1 - skin
        * 2 - left eyebrow
        * 3 - right eyebrow
        * 4 - left eye
        * 5 - right eye
        * 6 - eyeglasses
        * 7 - left ear
        * 8 - right ear
        * 9 - earring
        * 10 - nose
        * 11 - mouth
        * 12 - upper lip
        * 13 - lower lip
        * 14 - neck
        * 15 - necklace
        * 16 - clothes
        * 17 - hair
        * 18 - hat

    Some examples of grouping by attributes:

        * ``'glasses': [6]`` - this will put each face image that
          contains pixels labeled as 6 to a category called 'glasses'.
        * ``'earings_and_necklace': [9, 15]`` - this will put each
          image that contains pixels labeled as 9 and also contains
          pixels labeled as 15 to a category called
          'earings_and_necklace'.
        * ``'no_accessories': [-6, -9, -15, -18]`` - this will put each
          face image that does not contain pixels labeled as either 6,
          9, 15, or 18 to a category called 'no_accessories'.

    Some examples of grouping by mask:

        * ``'nose': [10]`` - this will put each face image that
          contains pixels labeled as 10 to a category called 'nose' and
          generate a corresponding mask.
        * ``'eyes_and_eyebrows': [2, 3, 4, 5]`` - this will put each
          image that contains pixels labeled as either 2, 3, 4, or 5
          (or any combination of them) to a category called
          'eyes_and_eyebrows' and generate a corresponding mask.

    This class also inherits ``load`` method from ``LoadMixin`` class.
    The method takes a device on which to load the model and loads the
    model with a default state dictionary loaded from
    ``WEIGHTS_FILENAME`` file. It sets this model to eval mode and
    disables gradients.

    For more information on how BiSeNet model works, see this repo:
    `Face Parsing PyTorch
    <https://github.com/zllrunning/face-parsing.PyTorch>`_. Most of the
    code was taken from that repository.

    Note:
        Whenever an input shape is mentioned, N corresponds to batch
        size, C corresponds to the number of channels, H - to input
        height, and W - to input width.

    By default, this class initializes the following attributes which
    can be changed after initialization of the class (but, typically,
    should not be changed):

    Attributes:
        attr_join_by_and (bool): Whether to add a face image to an
            attribute group if the face meets all the specified
            attributes in a list (joined by and) or at least one of
            the attributes (joined by or). Please read the definition
            of `attr_groups` to get a clearer picture. In most cases,
            this should be set True - if the attributes in a group
            list are negative, this will ensure the selected face will
            match none of the specified attributes. Also, if you want
            to join the attributes by or (any), then separate
            single-attribute groups can be created and manually merged
            into one. Defaults to True.
        attr_threshold (int): Threshold, based on which the attribute
            is determined as present in the face image. For instance,
            if the threshold is 5, then at least 6 pixels must be
            labeled of the same kind of attribute for that attribute
            to be considered present in the face image. Defaults to 5.
        mask_threshold (int): Threshold, based on which the mask is
            considered to be a proper mask. For instance, if the
            threshold is 15, then face images for which the number of
            pixels with the values corresponding to a specified mask
            group (face attributes) is less than or equal to 15 will
            be ignored and image-mask pair for that mask category will
            not be generated. Defaults to 15.
        mean (list[float]): The list of mean values for each input
            channel. The pixel values should be shifted by those
            quantities during inference since this normalization was
            applied during training. Defaults to
            [0.485, 0.456, 0.406].
        std (list[float]): The list of standard deviation values for
            each input channel. The pixel values should be scaled by
            those quantities during inference since this normalization
            was applied during training. Defaults to
            [0.229, 0.224, 0.225].
    """

    #: WEIGHTS_FILENAME (str): The constant specifying the name of
    #: ``.pth`` file from which the weights for this model should be
    #: loaded. Defaults to "bise_parser.pth".
    WEIGHTS_FILENAME = "bise_parser.pth"

    def __init__(
        self,
        attr_groups: dict[str, list[int]] | None = None,
        mask_groups: dict[str, list[int]] | None = None,
        max_batch_size: int = 8,
    ):
        """Initializes BiSeNet model.

        First it assigns the passed values as attributes. Then this
        method initializes BiSeNet layers required for face parsing,
        i.e., labeling face parts.

        Note:
            Check class definition for the possible face attribute
            values and examples of groups. Also note that all the
            specified variables here are mainly relevant only for
            :meth:`predict`.

        Args:
            attr_groups: Dictionary specifying attribute groups, based
                on which the face images should be grouped. Each key
                represents an attribute group name, e.g., 'glasses',
                'earings_and_necklace', 'no_accessories', and each
                value represents attribute indices, e.g., `[6]`,
                `[9, 15]`, `[-6, -9, -15, -18]`, each index mapping to
                some attribute. Since this model labels face image
                pixels, if there are enough pixels with the specified
                values in the list, the whole face image will be put
                into that attribute category. For negative values, it
                will be checked that the labeled face image does not
                contain those (absolute) values. If it is None, then
                there will be no grouping according to attributes.
                Defaults to None.
            mask_groups: Dictionary specifying mask groups, based on
                which the face images and their masks should be
                grouped. Each key represents a mask group name, e.g.,
                'nose', 'eyes_and_eyebrows', and each value represents
                attribute indices, e.g., `[10]`, `[2, 3, 4, 5]`, each
                index mapping to some attribute. Since this model
                labels face image pixels, a mask will be created with
                ones at pixels that match the specified attributes and
                zeros elsewhere. Note that negative values would make
                no sense here and having them would cause an error. If
                it is None, then there will be no grouping according
                to mask groups. Defaults to None.
            max_batch_size: The maximum batch size used when
                performing inference. There may be a lot of faces in a
                single batch, thus splitting to sub-batches for
                inference and then merging back predictions is a way
                to deal with memory errors. This is a convenience
                variable because batch size typically corresponds to
                the number of images for a single inference, but the
                input given in :meth:`predict` might have a larger
                batch size because it represents the number of faces,
                many of which can come from just a single image.
                Defaults to 8.
        """
        super().__init__()

        # Initialize class attributes
        self.attr_groups = attr_groups
        self.mask_groups = mask_groups
        self.batch_size = max_batch_size
        self.attr_join_by_and = True
        self.attr_threshold = 5
        # Must agree with the default documented in the class docstring
        self.mask_threshold = 15
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        # Init model layers
        self.cp = ContextPath()
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, 19)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Performs forward pass.

        Takes an input batch and performs inference based on the
        modules it has. The input is a batch of face images and the
        output is a corresponding batch of pixel-wise attribute scores.

        Args:
            x: The input tensor of shape (N, 3, H, W).

        Returns:
            An output tensor of shape (N, 19, H, W) where each channel
            corresponds to a specific attribute and each value at
            (H, W) is an unbounded confidence score.
        """
        # Generate final features from layers
        feat_out = self.conv_out(self.ffm(*self.cp(x)))

        # Upscale the scores back to the spatial size of the input
        return F.interpolate(
            feat_out,
            size=x.size()[2:],
            mode="bilinear",
            align_corners=True,
        )

    def group_by_attributes(
        self,
        parse_preds: torch.Tensor,
        attr_groups: dict[str, list[int]],
        offset: int,
    ) -> dict[str, list[int]]:
        """Groups parse predictions by face attributes.

        Takes parse predictions for each face where each pixel
        corresponds to some attribute group (the integer value
        indicates that group) and extends the groups in attribute
        dictionary to include more samples that match the group.

        Args:
            parse_preds: Face parsing predictions of shape (N, H, W)
                with integer values indicating pixel categories.
            attr_groups: The dictionary with keys corresponding to
                attribute group names (they match ``self.attr_groups``
                keys) and values corresponding to indices that map
                face images from other batches of ``parse_preds`` to
                the corresponding group. This is the dictionary that
                is extended (in place) and returned.
            offset: The offset to add to each index. Originally, the
                indices will correspond only to the face parsings in
                the current ``parse_preds`` batch and the offset
                allows to generalize each index by offsetting it by
                the number of previously processed face parsings,
                i.e., the offset is the number of previous batches
                (``parse_preds``) times the batch size.

        Returns:
            Similar to ``attr_groups``, it is the dictionary with the
            same keys but values (which are lists of indices) may be
            extended with additional indices.
        """
        # Specify function/criteria to join the attributes in a list
        att_join = torch.all if self.attr_join_by_and else torch.any

        for k, v in self.attr_groups.items():
            # Count, per face, the pixels of every attribute in the
            # group; the result has one column per entry of ``v``
            attr = torch.tensor(v, device=parse_preds.device).view(1, -1, 1, 1)
            counts = (parse_preds.unsqueeze(1) == attr.abs()).sum(dim=(2, 3))

            # Positive index -> attribute must be present, otherwise it
            # must be absent. Index 0 cannot be negated and, because
            # the test is ``a > 0``, it is evaluated as an absence check
            is_attr = att_join(torch.stack([
                counts[:, i] > self.attr_threshold if a > 0 else
                counts[:, i] <= self.attr_threshold
                for i, a in enumerate(v)
            ], dim=1), dim=1)

            # Add indices of those faces which match the group attribute
            inds = torch.nonzero(is_attr).flatten() + offset
            attr_groups[k].extend(inds.tolist())

        return attr_groups

    def group_by_masks(
        self,
        parse_preds: torch.Tensor,
        mask_groups: dict[str, tuple[list[int], list[np.ndarray]]],
        offset: int,
    ) -> dict[str, tuple[list[int], list[np.ndarray]]]:
        """Groups parse predictions by face mask attributes.

        Takes parse predictions for each face where each pixel
        corresponds to some parse/mask group (the integer value
        indicates that group) and extends the groups in mask
        dictionary to include more samples that match the group.

        Args:
            parse_preds: Face parsing predictions of shape (N, H, W)
                with integer values indicating pixel categories.
            mask_groups: The dictionary with keys corresponding to
                mask group names (they match ``self.mask_groups``
                keys) and values corresponding to tuples where the
                first value is a list of indices that map face images
                from other batches of ``parse_preds`` to the
                corresponding group and the second is a list of
                corresponding masks as numpy arrays of shape (H, W) of
                type :attr:`numpy.uint8` with 255 at pixels that match
                the mask group specification and 0 elsewhere. This is
                the dictionary that is extended (in place) and
                returned.
            offset: The offset to add to each index. Originally, the
                indices will correspond only to the face parsings in
                the current ``parse_preds`` batch and the offset
                allows to generalize each index by offsetting it by
                the number of previously processed face parsings,
                i.e., the offset is the number of previous batches
                (``parse_preds``) times the batch size.

        Returns:
            Similar to ``mask_groups``, it is the dictionary with the
            same keys but values (which are tuples of a list of
            indices and a list of masks) may be extended with
            additional indices and masks.
        """
        # Retrieve threshold (shorter name)
        threshold = self.mask_threshold

        for k, v in self.mask_groups.items():
            # A pixel belongs to the mask if it matches any attribute
            attr = torch.tensor(v, device=parse_preds.device).view(1, -1, 1, 1)
            mask = (parse_preds.unsqueeze(1) == attr).any(dim=1)

            # Keep only masks with strictly more pixels than threshold;
            # a boolean selector preserves the batch order and is also
            # valid when no face qualifies
            keep = mask.sum(dim=(1, 2)) > threshold
            inds = torch.nonzero(keep).flatten().tolist()
            masks = mask[keep].mul(255).cpu().numpy().astype(np.uint8)

            # Extend the lists of indices and masks for k group
            mask_groups[k][0].extend(i + offset for i in inds)
            mask_groups[k][1].extend(masks)

        return mask_groups

    @torch.no_grad()
    def predict(
        self,
        images: torch.Tensor | list[torch.Tensor],
    ) -> tuple[dict[str, list[int]] | None,
               dict[str, tuple[list[int], list[np.ndarray]]] | None]:
        """Predicts attribute and mask groups for face images.

        This method takes a batch of face images and groups them
        according to the specifications in ``self.attr_groups`` and
        ``self.mask_groups``. For more information on how it works,
        see this class' specification :class:`BiSeNet`. It returns 2
        group maps - one for grouping face images to different
        attribute categories, e.g., 'with glasses', 'no accessories'
        and the other for grouping images to different mask groups,
        e.g., 'nose', 'lips and mouth'.

        Args:
            images: Image batch of shape (N, 3, H, W) in RGB form with
                float values from 0.0 to 255.0. It must be on the same
                device as this model. A list of tensors can also be
                provided, however, they all must have the same spatial
                dimensions to be stack-able to a single batch.

        Returns:
            A tuple of 2 dictionaries (either can be None):

                1. ``attr_groups`` - each key represents attribute
                   category and each value is a list of indices
                   indicating which samples from ``images`` batch
                   belong to that category. It can be None if
                   ``self.attr_groups`` is None.
                2. ``mask_groups`` - each key represents attribute
                   (mask) category and each value is a tuple where the
                   first element is a list of indices indicating which
                   samples from ``images`` batch belong to that mask
                   group and the second element is a corresponding
                   batch of masks of shape (N, H, W) of type
                   :attr:`numpy.uint8` with values of either 0 or 255.
                   The masks are presented in the same order as the
                   indices that indicate which face images to take for
                   that mask group. It can be None if
                   ``self.mask_groups`` is None.
        """
        # Initialize groups as None, a helper offset
        attr_groups, mask_groups, offset = None, None, 0

        if self.attr_groups is not None:
            # Initialize an empty dictionary of attribute groups
            attr_groups = {k: [] for k in self.attr_groups}

        if self.mask_groups is not None:
            # Initialize an empty dictionary of mask groups
            mask_groups = {k: ([], []) for k in self.mask_groups}

        if isinstance(images, list):
            # Stack the list of tensors
            images = torch.stack(images)

        # Convert mean and std to tensors and reshape, resize images
        mean = torch.tensor(self.mean, device=images.device).view(1, 3, 1, 1)
        std = torch.tensor(self.std, device=images.device).view(1, 3, 1, 1)
        x = F.interpolate(images.div(255), (512, 512), mode="bilinear")

        for sub_x in torch.split(x, self.batch_size):
            # Inference and resize back
            o = self((sub_x - mean) / std)
            o = F.interpolate(o, images.size()[2:], mode="nearest").argmax(1)

            if self.attr_groups is not None:
                # Extend each attribute group based on predictions
                attr_groups = self.group_by_attributes(o, attr_groups, offset)

            if self.mask_groups is not None:
                # Extend each mask group based on predictions
                mask_groups = self.group_by_masks(o, mask_groups, offset)

            # Increment offset
            offset += len(sub_x)

        if attr_groups is not None:
            # Filter out groups for which the list of indices is empty
            attr_groups = {k: v for k, v in attr_groups.items() if len(v) > 0}

        if mask_groups is not None:
            # Filter out groups for which the list of indices is empty
            # and merge the remaining per-face masks to a single array
            mask_groups = {
                k: (v[0], np.stack(v[1]))
                for k, v in mask_groups.items() if len(v[1]) > 0
            }

        return attr_groups, mask_groups