Source code for glasses_detector.architectures.tiny_binary_segmenter

import torch
import torch.nn as nn
from torchvision.ops import Conv2dNormActivation


class TinyBinarySegmenter(nn.Module):
    """Tiny binary segmenter.

    This is a custom segmenter created with the aim to contain very few
    parameters while maintaining a reasonable accuracy. It only has
    several sequential down-sampling and up-sampling convolution stages
    joined by skip connections (encoder feature maps are concatenated
    onto the up-sampled decoder feature maps) and is very similar to
    U-Net.

    Note:
        You can read more about U-Net architecture in the following
        paper by O. Ronneberger et al.:
        `U-Net: Convolutional Networks for Biomedical Image Segmentation
        <https://arxiv.org/abs/1505.04597>`_
    """

    class _Down(nn.Module):
        """Encoder stage: 2x2 max pooling followed by two conv blocks."""

        def __init__(self, in_channels: int, out_channels: int) -> None:
            super().__init__()
            # Attribute names are part of the state-dict layout; keep them.
            self.pool0 = nn.MaxPool2d(2)
            self.conv1 = Conv2dNormActivation(in_channels, out_channels)
            self.conv2 = Conv2dNormActivation(out_channels, out_channels)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Pool ``x`` and pass it through both conv blocks."""
            return self.conv2(self.conv1(self.pool0(x)))

    class _Up(nn.Module):
        """Decoder stage: up-sample, merge with a skip tensor, convolve.

        ``in_channels`` is the channel count *after* concatenation, so
        the tensor being up-sampled and the skip tensor are each
        expected to carry ``in_channels // 2`` channels.
        """

        def __init__(self, in_channels: int, out_channels: int) -> None:
            super().__init__()
            half_channels = in_channels // 2
            # Kernel 2 / stride 2 transposed conv doubles height and width.
            self.conv0 = nn.ConvTranspose2d(half_channels, half_channels, 2, 2)
            self.conv1 = Conv2dNormActivation(in_channels, out_channels)
            self.conv2 = Conv2dNormActivation(out_channels, out_channels)

        def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
            """Up-sample ``x1``, align it to ``x2`` and fuse the two.

            Args:
                x1 (torch.Tensor): Lower-resolution decoder features.
                x2 (torch.Tensor): Skip features from the encoder whose
                    spatial size is the target size.

            Returns:
                torch.Tensor: Fused features with the spatial size of
                ``x2``.
            """
            x1 = self.conv0(x1)

            # Max pooling floors odd sizes, so the up-sampled map can be
            # one pixel short of the skip map; pad it to match.
            diff_y = x2.shape[2] - x1.shape[2]
            diff_x = x2.shape[3] - x1.shape[3]

            # Padding order is (left, right, top, bottom).
            x1 = nn.functional.pad(
                x1,
                (
                    diff_x // 2,
                    diff_x - diff_x // 2,
                    diff_y // 2,
                    diff_y - diff_y // 2,
                ),
            )

            # Skip features go first along the channel axis.
            x = torch.cat([x2, x1], dim=1)

            return self.conv2(self.conv1(x))

    def __init__(self):
        super().__init__()

        # Feature extraction layer
        self.first = nn.Sequential(
            Conv2dNormActivation(3, 16),
            Conv2dNormActivation(16, 16),
        )

        # Down-sampling layers
        self.down1 = self._Down(16, 32)
        self.down2 = self._Down(32, 64)
        self.down3 = self._Down(64, 64)

        # Up-sampling layers (first argument = channels after concatenation)
        self.up1 = self._Up(128, 32)
        self.up2 = self._Up(64, 16)
        self.up3 = self._Up(32, 16)

        # Pixel-wise classification layer
        self.last = nn.Conv2d(16, 1, 1)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Performs forward pass.

        Predicts raw pixel scores for the given batch of inputs. Scores
        are unbounded - anything that's less than 0 means positive class
        belonging to the pixel is unlikely and anything that's above 0
        indicates that positive class for a particular pixel is likely.

        Args:
            x (torch.Tensor): Image batch of shape (N, C, H, W) where
                C is 3 (the first convolution takes 3 input channels).
                Note that pixel values are expected to be normalized
                and squeezed between 0 and 1. The input is max-pooled
                by a factor of 2 three times, so H and W presumably
                need to be at least 8 for the deepest feature map to be
                non-empty.

        Returns:
            dict[str, torch.Tensor]: A dictionary with a single "out"
            entry (for compatibility). The value is an output tensor of
            shape (N, 1, H, W) indicating which pixels in the image fall
            under positive category. The scores are unbounded, thus, to
            convert to probabilities, sigmoid function must be used.
        """
        # Extract primary features
        x1 = self.first(x)

        # Downsample features
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        # Upsample features, fusing each stage with its encoder counterpart
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)

        # Predict one channel
        out = self.last(x)

        return {"out": out}