# Source code for glasses_detector.components.base_model

import inspect
import os
import warnings
from abc import abstractmethod
from dataclasses import dataclass, field, fields
from typing import (
    Any,
    Callable,
    ClassVar,
    Collection,
    Iterable,
    Self,
    overload,
    override,
)

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms.v2.functional import normalize, to_dtype, to_image

from ..utils import FilePath, copy_signature, eval_infer_mode, is_path_type, is_url
from .pred_interface import PredInterface
from .pred_type import Default


@dataclass
class BaseGlassesModel(PredInterface):
    """Base class for all glasses models.

    Base class with common functionality, i.e., prediction and weight
    loading methods, that should be inherited by all glasses models.
    Child classes must implement :meth:`.create_model` method which
    should return the model architecture based on :attr:`.model_info`
    which is a dictionary containing the model name and the release
    version. The dictionary depends on the model's :attr:`kind` and
    :attr:`size`, both of which are used when creating an instance.
    An instance can be created by providing a custom model instead of
    creating a predefined one, see :meth:`from_model`.

    Note:
        When ``weights`` is :data:`True`, the URL of the weights to be
        downloaded from will be constructed automatically based on
        :attr:`model_info`. According to
        :func:`~torch.hub.load_state_dict_from_url`, first, the
        corresponding weights will be checked if they are already
        present in the hub cache, which by default is
        ``~/.cache/torch/hub/checkpoints``, and, if they are not, the
        weight will be downloaded there and then loaded.

    Important:
        To train the actual model parameters, i.e., the model of type
        :class:`torch.nn.Module`, retrieve it using the :attr:`model`
        attribute.

    Args:
        task (str): The task the model is built for. Used when
            automatically constructing URL to download the weights
            from.
        kind (str): The kind of the model. Used to access
            :attr:`.model_info`.
        size (str): The size of the model. Used to access
            :attr:`.model_info`.
        weights (bool | str | None, optional): Whether to load the
            pre-trained weights from a custom URL (or a local file if
            they're already downloaded) which will be inferred based on
            model's :attr:`task`, :attr:`kind`, and :attr:`size`. If a
            string is provided, it will be used as a path or a URL
            (determined automatically) to the model weights. Defaults
            to :data:`False`.
        device (str | torch.device | None, optional): Device to cast
            the model to (once it is loaded). If specified as
            :data:`None`, it will be automatically checked if
            `CUDA <https://developer.nvidia.com/cuda-toolkit>`_ or
            `MPS <https://developer.apple.com/documentation/metalperformanceshaders>`_
            is supported. Defaults to :data:`None`.
    """

    task: str
    kind: str
    size: str
    weights: bool | str | None = False
    device: str | torch.device | None = None
    # NOTE: set to ``None`` by ``__post_init__`` when the architecture
    # cannot be created (unknown kind/size or unknown model name).
    model: nn.Module = field(init=False, repr=False)

    BASE_WEIGHTS_URL: ClassVar[str] = (
        "https://github.com/mantasu/glasses-detector/releases/download"
    )
    """
    typing.ClassVar[str]: The base URL to download the weights from.
    """

    ALLOWED_SIZE_ALIASES: ClassVar[dict[str, set[str]]] = {
        "small": {"small", "little", "s"},
        "medium": {"medium", "normal", "m"},
        "large": {"large", "big", "l"},
    }
    """
    typing.ClassVar[dict[str, set[str]]]: The map from each standard
    size to the set of its allowed aliases. These are used to convert
    an alias to a standard size when accessing :attr:`.model_info`.
    Available aliases are:

    +------------------------------+---------------------------------+---------------------------+
    | Small                        | Medium                          | Large                     |
    +==============================+=================================+===========================+
    | ``small``, ``little``, ``s`` | ``medium``, ``normal``, ``m``   | ``large``, ``big``, ``l`` |
    +------------------------------+---------------------------------+---------------------------+

    Note:
        Any case is acceptable, for example, **Small** can be specified
        as ``"small"``, ``"S"``, ``"SMALL"``, ``"Little"``, etc.

    :meta hide-value:
    """

    DEFAULT_SIZE_MAP: ClassVar[dict[str, dict[str, str]]] = {
        "<size>": {"name": "<architecture-name>", "version": "<version>"}
    }
    """
    typing.ClassVar[dict[str, dict[str, str]]]: The default size map
    from the size of the model to the model info dictionary which
    contains the name of the architecture and the version of the
    weights release. This is just a helper component for
    :attr:`DEFAULT_KIND_MAP` because each default kind has the same set
    of default models.

    Example:
        .. code-block:: python

            >>> [info["name"] for info in DEFAULT_SIZE_MAP.values()]
            # list of all the available architectures

    :meta hide-value:
    """

    DEFAULT_KIND_MAP: ClassVar[dict[str, dict[str, dict[str, str]]]] = {
        "<kind>": DEFAULT_SIZE_MAP,
    }
    """
    typing.ClassVar[dict[str, dict[str, dict[str, str]]]]: The default
    map from model :attr:`kind` and :attr:`size` to the model info
    dictionary. The model info is used to construct the URL to download
    the weights from. The nested dictionary has 3 levels which are
    expected to be as follows:

    1. ``kind`` - the kind of the model
    2. ``size`` - the size of the model
    3. ``info`` - the model info, i.e., ``"name"`` and ``"version"``

    Example:
        .. code-block:: python

            >>> DEFAULT_KIND_MAP["<kind>"]["<size>"]
            {'name': '<architecture-name>', 'version': '<release-version>'}

    :meta hide-value:
    """

    def __post_init__(self):
        """Creates the inner model, picks the device and loads weights.

        If the model info or the architecture cannot be resolved, a
        warning is issued and :attr:`model` is left as :data:`None` so
        that a custom model can be assigned afterwards.
        """
        super().__init__()

        # Always define the attribute: the code below and `load_weights`
        # check for `None` instead of failing with an AttributeError
        self.model = None

        try:
            # Get the model name and create it
            model_name = self.model_info["name"]
            self.model = self.create_model(model_name)
        except KeyError:
            # Raise model info warning
            self._model_info_warning()
        except ValueError:
            # Raise model init (structure construction) warning
            message = f"Model structure named {model_name} does not exist. "
            self._model_init_warning(message=message)

        if self.device is None:
            if torch.cuda.is_available():
                # Set device to CUDA if available
                self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                # Set device to MPS if available
                self.device = torch.device("mps")
            else:
                # Set device to CPU by default
                self.device = torch.device("cpu")

        if self.weights:
            # Load weights if True or path is specified
            self.load_weights(path_or_url=self.weights)

        if self.model is not None:
            # Cast to device (only if the model was actually created)
            self.model.to(self.device)

    @property
    def model_info(self) -> dict[str, str]:
        """Model info property.

        This contains the information about the model used (e.g.,
        architecture and weights). By default, it should have 2 fields:
        ``"name"`` and ``"version"``, both of which are used when
        initializing the architecture and looking for pretrained
        weights (see :meth:`load_weights`).

        Note:
            This is the default implementation which accesses
            :attr:`DEFAULT_KIND_MAP` based on :attr:`kind` and
            :attr:`size`. Child classes can override either
            :attr:`DEFAULT_KIND_MAP` or this property itself for a
            custom dictionary.

        Returns:
            dict[str, str]: The model info dictionary with 2 fields -
            ``"name"`` and ``"version"`` which allow to construct model
            architecture and download the pretrained model weights, if
            present. An empty dictionary is returned if :attr:`kind` or
            :attr:`size` is not recognized.
        """
        alias = self.size.lower()

        # Map the alias to its standard size; keep as is if unknown
        size = next(
            (
                standard
                for standard, aliases in self.ALLOWED_SIZE_ALIASES.items()
                if alias in aliases
            ),
            self.size,
        )

        return self.DEFAULT_KIND_MAP.get(self.kind, {}).get(size, {})

    @staticmethod
    @abstractmethod
    def create_model(model_name: str) -> nn.Module:
        """Creates the model architecture.

        Takes the name of the model architecture and returns the
        corresponding model instance.

        Args:
            model_name (str): The name of the model architecture to
                create. For available architectures, see the class
                description (**Size Information** table) or
                :attr:`DEFAULT_SIZE_MAP`.

        Returns:
            torch.nn.Module: The model instance with the corresponding
            architecture.

        Raises:
            ValueError: If the architecture for the model name is not
                implemented or is not valid.
        """
        ...

    @classmethod
    def from_model(cls, model: nn.Module, **kwargs) -> Self:
        """Creates a glasses model from a custom :class:`torch.nn.Module`.

        Creates a glasses model wrapper for a custom provided
        :class:`torch.nn.Module`, instead of creating a predefined one
        based on :attr:`kind` and :attr:`size`.

        Note:
            Make sure the provided model's ``forward`` method behaves as
            expected, i.e., returns the prediction in expected format
            for compatibility with :meth:`predict`.

        Warning:
            :attr:`model_info` property will not be useful as it would
            return an empty dictionary for custom specified
            :attr:`kind` and :attr:`size` (if specified at all).

        Args:
            model (torch.nn.Module): The custom model that will be
                assigned as :attr:`model`.
            **kwargs: Keyword arguments to pass to the constructor;
                check the documentation of this class for more details.
                If ``task``, ``kind``, and ``size`` are not provided,
                they will be set to ``"custom"``. If the model
                architecture is custom, you may still specify the path
                to the pretrained weights via ``weights`` argument.
                Finally, if ``device`` is not provided, the model will
                remain on the same device as is.

        Returns:
            typing.Self: The glasses model wrapper of the same class
            type from which this method was called for the provided
            custom model.
        """
        # Device the custom model currently lives on (CPU if the module
        # has no parameters to infer it from)
        first_param = next(iter(model.parameters()), None)
        model_device = (
            first_param.device if first_param is not None else torch.device("cpu")
        )

        # Set default values for class args
        kwargs.setdefault("task", "custom")
        kwargs.setdefault("kind", "custom")
        kwargs.setdefault("size", "custom")
        kwargs.setdefault("device", model_device)

        # Weights will be handled after instantiation: remember what was
        # requested and prevent the constructor from loading anything
        weights = kwargs.get("weights", False)
        kwargs["weights"] = False

        # Filter out the arguments that are not for the model init
        is_init = {f.name: f.init for f in fields(cls)}
        kwargs = {k: v for k, v in kwargs.items() if is_init[k]}

        with warnings.catch_warnings():
            # Ignore warnings from model init
            warnings.simplefilter("ignore")
            glasses_model = cls(**kwargs)

        # Assign the actual model
        glasses_model.model = model

        if weights:
            # Load weights if `weights` is True or a path
            glasses_model.load_weights(path_or_url=weights)

        # Cast to the device resolved by the constructor, i.e., the one
        # given by the caller or, by default, the model's own device
        glasses_model.model.to(glasses_model.device)

        return glasses_model

    @overload
    def predict(
        self,
        image: FilePath | Image.Image | np.ndarray,
        format: (
            Callable[[Any], Default] | Callable[[Image.Image, Any], Default]
        ) = lambda x: str(x),
        input_size: tuple[int, int] | None = (256, 256),
    ) -> Default: ...

    @overload
    def predict(
        self,
        image: Collection[FilePath | Image.Image | np.ndarray],
        format: (
            Callable[[Any], Default] | Callable[[Image.Image, Any], Default]
        ) = lambda x: str(x),
        input_size: tuple[int, int] | None = (256, 256),
    ) -> list[Default]: ...

    @override
    def predict(
        self,
        image: (
            FilePath
            | Image.Image
            | np.ndarray
            | Collection[FilePath | Image.Image | np.ndarray]
        ),
        format: (
            Callable[[Any], Default] | Callable[[Image.Image, Any], Default]
        ) = lambda x: str(x),
        input_size: tuple[int, int] | None = (256, 256),
    ) -> Default | list[Default]:
        """Predicts based on the model specified by the child class.

        Takes a path or multiple paths to image files or the loaded
        images themselves and outputs a formatted prediction generated
        by the child class.

        Note:
            This method expects that :meth:`forward` always returns an
            :class:`~typing.Iterable` of any type of predictions
            (typically, they would be of type :class:`~torch.Tensor`),
            even if there is only one prediction. Likewise,
            :class:`~torch.Tensor` representing a batch of loaded images
            is passed to :meth:`forward` when generating those
            predictions.

        Important:
            If the image is provided as :class:`numpy.ndarray`, make
            sure the last dimension specifies the channels, i.e., last
            dimension should be of size ``1`` or ``3``. If it is
            anything else, e.g., if the shape is ``(3, H, W)``, where
            ``W`` is neither ``1`` nor ``3``, this would be interpreted
            as 3 grayscale images.

        .. seealso::

            :meth:`forward`

        Args:
            image (FilePath | PIL.Image.Image | numpy.ndarray | typing.Collection[FilePath | PIL.Image.Image | numpy.ndarray]):
                The path(-s) to the image to generate the prediction for
                or the image(-s) itself represented as
                :class:`Image.Image` or as a :class:`numpy.ndarray`.
                Note that the image should have values between 0 and 255
                and be of RGB format. Normalization is not needed as the
                channels will be automatically normalized before passing
                through the network.
            format (typing.Callable[[typing.Any], Default] | typing.Callable[[PIL.Image.Image, typing.Any], Default], optional):
                Format callback. This is a custom function that takes
                the predicted elements from the iterable output of
                :meth:`forward` (elements are usually of type
                :class:`~torch.Tensor`) as input or the original image
                and its prediction as inputs (it will be determined
                automatically which function it is) and outputs a
                formatted prediction of type :attr:`Default`. Defaults
                to ``lambda x: str(x)``.
            input_size (tuple[int, int] | None, optional): The size
                (width, height), or ``(W, H)``, to resize the image to
                before passing it through the network. If :data:`None`,
                the image will not be resized. It is recommended to
                resize it to the size the model was trained on, which by
                default is ``(256, 256)``. Defaults to ``(256, 256)``.

        Returns:
            Default | list[Default]: The formatted prediction or a list
            of formatted predictions if multiple images were provided
            (an empty list for an empty collection).
        """
        # Init mean + std (default from albumentations) and others
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        device = next(iter(self.model.parameters())).device
        xs, preds, is_multiple = [], [], True

        # Warning: if the image has shape (3, H, W),
        # it will be interpreted as 3 grayscale images
        if (is_path_type(image) or isinstance(image, Image.Image)) or (
            isinstance(image, np.ndarray)
            and (image.ndim == 2 or (image.ndim == 3 and image.shape[-1] in [1, 3]))
        ):
            # Single image
            image = [image]
            is_multiple = False

        if require_original := (len(inspect.signature(format).parameters) == 2):
            # Init original images
            original_images = []

        for img in image:
            if isinstance(img, (str, bytes, os.PathLike)):
                # Load from the path and ensure RGB; `convert` loads the
                # pixel data, so the file handle can be closed right away
                with Image.open(img) as opened:
                    img = opened.convert("RGB")
            elif isinstance(img, np.ndarray):
                if img.ndim == 3 and img.shape[-1] == 1:
                    # PIL cannot build an image from (H, W, 1) arrays,
                    # drop the singleton channel axis first
                    img = img[..., 0]

                # Convert to PIL image and ensure RGB
                img = Image.fromarray(img).convert("RGB")
            elif isinstance(img, Image.Image):
                # Ensure RGB so the 3-channel normalization below applies
                img = img.convert("RGB")

            if require_original:
                # Keep track of original
                original_images.append(img)

            if input_size is not None:
                # Resize the image
                img = img.resize(input_size)

            # Convert to tensor, normalize; add to xs
            x = to_dtype(to_image(img), scale=True)
            xs.append(normalize(x, mean=mean, std=std))

        if not xs:
            # Nothing to predict for (an empty batch cannot be stacked)
            return []

        with eval_infer_mode(self.model):
            # Stack and cast to device; perform forward pass
            batch_preds = self.forward(torch.stack(xs).to(device))

        for i, pred in enumerate(batch_preds):
            if require_original:
                # Format prediction with original image
                preds.append(format(original_images[i], pred))
            else:
                # Append formatted prediction
                preds.append(format(pred))

        return preds if is_multiple else preds[0]

    def forward(self, x: torch.Tensor) -> Iterable[Any]:
        """Performs forward pass.

        Calls the forward method of the inner :attr:`model`, by passing
        a batch of images as its first argument.

        Tip:
            If this method is used during inference, make sure to set
            the model to evaluation mode and enable
            :class:`~torch.inference_mode`, e.g., via
            :class:`.eval_infer_mode` decorator/context manager.

        Note:
            The default :meth:`predict` that uses this method assumes an
            input is a batch of images of type :class:`~torch.Tensor`
            and the output can be anything that is
            :class:`~typing.Iterable`, e.g., a :class:`~torch.Tensor`.

        Warning:
            In case of a custom inner :attr:`model` (e.g., if the
            instance was created using :meth:`from_model`) that does not
            accept a tensor representing a batch of images as its first
            argument, this method will not work, in which case
            :meth:`predict` will also not work.

        .. seealso::

            :meth:`predict`

        Args:
            x: A batch of images - a :class:`~torch.Tensor` of shape
                ``(N, C, H, W)`` with normalized pixel values between
                ``0`` and ``1``.

        Returns:
            An iterable of predictions (one for each input). Usually, it
            is a :class:`~torch.Tensor` with the first dimension of size
            ``N`` which is the batch size of the original input.
        """
        return self.model(x)

    def load_weights(self, path_or_url: str | bool = True):
        """Loads inner :attr:`model` weights.

        Takes a path or a URL to the weights file, or :data:`True` to
        construct the URL automatically based on :attr:`model_info` and
        loads the weights into :attr:`model`.

        Note:
            If the weights are already downloaded, they will be loaded
            from the hub cache, which by default is
            ``~/.cache/torch/hub/checkpoints``.

        Warning:
            If the fields in :attr:`model_info` are not recognized,
            e.g., by providing an unrecognized :attr:`kind` or
            :attr:`size` or by initializing with :meth:`from_model`,
            this method will not be able to construct the URL (if
            ``path_or_url`` is :data:`True`) and will raise a warning.

        Args:
            path_or_url (str | bool, optional): The path or the URL (it
                will be inferred automatically) to the model weights
                (``.pth`` file). It can also be :class:`bool`, in which
                case :data:`True` indicates to construct ``URL`` for the
                pre-trained weights and :data:`False` does nothing.
                Defaults to :data:`True`.

        Raises:
            NotImplementedError: If ``path_or_url`` is :data:`True` and
                :attr:`size` is one of the aliases of the large size.
        """
        if isinstance(path_or_url, bool) and path_or_url:
            try:
                # Get model name and release version
                name = self.model_info["name"]
                version = self.model_info["version"]
            except KeyError:
                # Raise model info warning for not constructing URL
                message = "Path/URL to weights cannot be constructed. "
                self._model_info_warning(message)
                return

            if self.size.lower() in self.ALLOWED_SIZE_ALIASES["large"]:
                raise NotImplementedError("Large models are not supported yet")

            # Construct weights URL from base URL and model info
            weights_name = f"{self.task}_{self.kind}_{name}.pth"
            path_or_url = f"{self.BASE_WEIGHTS_URL}/{version}/{weights_name}"
        elif isinstance(path_or_url, bool):
            return

        # `getattr` guards against the attribute never being assigned
        if getattr(self, "model", None) is None:
            # Raise model init warning for not loading weights
            message = "Cannot load weights for the unspecified model. "
            self._model_init_warning(message)
            return

        if is_url(path_or_url):
            # Get weights from download path (and download if needed)
            weights = torch.hub.load_state_dict_from_url(
                url=path_or_url,
                map_location=self.device,
            )
        else:
            # Load weights from local path
            # NOTE(review): `torch.load` may unpickle arbitrary objects;
            # only load weight files from trusted sources (consider
            # `weights_only=True` if all supported versions allow it)
            weights = torch.load(path_or_url, map_location=self.device)

        # Actually load the weights
        self.model.load_state_dict(weights)

        if self.device is not None:
            # Cast model to device
            self.model.to(self.device)

    def _model_info_warning(self, message: str = ""):
        """Warns that no model info exists for the current configuration.

        Args:
            message (str, optional): The text to prepend to the warning.
                Defaults to ``""``.
        """
        warnings.warn(
            f"{message}Model info (name and release version) not found for the "
            f"specified configuration: {self.task=} {self.kind=} {self.size=}."
        )

    def _model_init_warning(self, message: str = ""):
        """Warns that the inner model is not initialized.

        Args:
            message (str, optional): The text to prepend to the warning.
                Defaults to ``""``.
        """
        warnings.warn(
            f"{message}Model is not initialized. Try assigning a custom model "
            f"via `self.model` attribute, for instance, create a custom model "
            f"using `GlassesModel.create_model` and assign it."
        )

    @copy_signature(predict)
    def __call__(self, *args, **kwargs):
        # Calling the instance is a shorthand for `predict`
        return self.predict(*args, **kwargs)
if __name__ == "__main__":
    # Quick manual check: build a model from a demo configuration
    demo_config = {
        "task": "binary",
        "kind": "detection",
        "size": "eyes-large",
    }
    model = BaseGlassesModel(**demo_config)