Source code for cryovit.models.unet3d

"""UNet3D model architecture for 3D tomogram segmentation."""

import math

import torch
from torch import Tensor, nn

from cryovit.models.base_model import BaseModel
from cryovit.types import BatchedTomogramData


[docs] class UNet3D(BaseModel): """UNet3D model implementation."""
[docs] def __init__(self, **kwargs) -> None: """Initializes the UNet3D model with specific analysis and synthesis blocks.""" super().__init__(**kwargs) self.bottom_layer = nn.Sequential( nn.Conv3d(256, 384, 3, padding="same"), nn.InstanceNorm3d(384, eps=1e-3, affine=True), nn.GELU(), nn.Conv3d(384, 256, 3, padding="same"), nn.InstanceNorm3d(256, eps=1e-3, affine=True), nn.GELU(), ) self.analysis_layers = nn.ModuleList( [ AnalysisBlock(1, 16), AnalysisBlock(16, 64), AnalysisBlock(64, 256), ] ) self.synthesis_layers = nn.ModuleList( [ SynthesisBlock(256, 256, 64), SynthesisBlock(64, 64, 16), SynthesisBlock(16, 16, 16), ] ) self.output_layer = nn.Conv3d(16, 1, 1, padding="same") self.PAD = max(16, 2 ** len(self.analysis_layers))
[docs] def forward_volume(self, x: Tensor) -> Tensor: """Memory optimized forward pass for the UNet3D model.""" analysis_outputs = [] for block in self.analysis_layers: x, prev_x = block(x) analysis_outputs.append(prev_x) for layer in self.bottom_layer: x = layer(x) for block in self.synthesis_layers: for layer in block.upconv: # type: ignore x = layer(x) x = torch.cat([x, analysis_outputs.pop()], 1) for layer in block.layers: # type: ignore x = layer(x) x = self.output_layer(x) x = torch.clip(x, -5.0, 5.0) return x
[docs] def forward(self, batch: BatchedTomogramData) -> Tensor: # type: ignore x = batch.tomo_batch # (B, D, C, H, W) x = x.permute(0, 2, 1, 3, 4) # (B, C, D, H, W) # Add padding new_size = [ self.PAD * math.ceil(dim / self.PAD) for dim in x.size()[-3:] ] D, H, W = x.size()[-3:] if new_size != [D, H, W]: x = self._add_padding(x, new_size) x = self.forward_volume(x) # (B, 1, D, H, W) # Remove padding if new_size != [D, H, W]: x = x[..., :D, :H, :W] x = x.squeeze(1) # (B, D, H, W) return torch.sigmoid(x)
@torch.inference_mode() def _add_padding(self, x: Tensor, new_size: list[int]) -> Tensor: """Adds padding to the input to match the U-Net dimensions""" # pad to multiple of self.PAD D, H, W = x.size()[-3:] new_shape = list(x.shape[:-3]) + new_size x_new = torch.zeros(*new_shape, dtype=x.dtype, device=x.device) x_new[..., :D, :H, :W] = x return x_new
class AnalysisBlock(nn.Module): """Block that performs down-sampling and feature extraction in UNet3D's analysis path.""" def __init__(self, in_channels: int, out_channels: int) -> None: """Initializes the AnalysisBlock with convolutional layers and pooling. Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels. """ super().__init__() self.pool = nn.Sequential( nn.Conv3d(out_channels, out_channels, 2, stride=2), nn.InstanceNorm3d(out_channels, eps=1e-3, affine=True), nn.GELU(), ) self.layers = nn.Sequential( nn.Conv3d(in_channels, out_channels, 3, padding="same"), nn.InstanceNorm3d(out_channels, eps=1e-3, affine=True), nn.GELU(), nn.Conv3d(out_channels, out_channels, 3, padding="same"), nn.InstanceNorm3d(out_channels, eps=1e-3, affine=True), nn.GELU(), ) def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: """Forward pass for the analysis block.""" x = self.layers(x) y = self.pool(x) return y, x class SynthesisBlock(nn.Module): """Block that performs up-sampling and feature combination in UNet3D's synthesis path.""" def __init__( self, in_channels: int, skip_channels: int, out_channels: int ) -> None: """Initializes the SynthesisBlock with up-sampling and convolutional layers. Args: in_channels (int): Number of input channels. skip_channels (int): Number of channels from the skip connection. out_channels (int): Number of output channels. """ super().__init__() self.upconv = nn.Sequential( nn.ConvTranspose3d(in_channels, out_channels, 2, stride=2), nn.InstanceNorm3d(out_channels, eps=1e-3, affine=True), nn.GELU(), ) self.layers = nn.Sequential( LinearProjection(out_channels + skip_channels, out_channels), nn.InstanceNorm3d(out_channels, eps=1e-3, affine=True), nn.GELU(), nn.Conv3d(out_channels, out_channels, 3, padding="same"), nn.InstanceNorm3d(out_channels, eps=1e-3, affine=True), nn.GELU(), ) def forward(self, x: Tensor, skip_x: Tensor) -> Tensor: """Forward pass for the synthesis block.""" x = self.upconv(x) x = torch.cat([x, skip_x], 1) # channel concat x = self.layers(x) return x class LinearProjection(nn.Module): """Linear projection of an input tensor.""" def __init__(self, in_channels: int, out_channels: int) -> None: """Initializes the LinearProjection with a linear transformation. Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels after projection. """ super().__init__() self.proj = nn.Linear(in_channels, out_channels) def forward(self, x: Tensor) -> Tensor: """Forward pass for the linear projection layer.""" x = torch.permute(x, (0, 2, 3, 4, 1)) x = self.proj(x) x = torch.permute(x, (0, 4, 1, 2, 3)) return x