"""CryoVIT model architecture for 3D tomogram segmentation."""
import torch
from torch import Tensor, nn
from cryovit.models.base_model import BaseModel
from cryovit.types import BatchedTomogramData
class CryoVIT(BaseModel):
    """CryoVIT model implementation.

    A 1x1x1 convolution first projects the 1536-channel input volume down to
    1024 channels. Four ``SynthesisBlock`` stages then reduce the channel
    count to 8 while each doubles the last two spatial axes (16x in total),
    and a small convolutional head maps the result to a single output channel.
    """

    def __init__(self, **kwargs) -> None:
        """Initializes the CryoVIT model with specific convolutional and synthesis blocks.

        Args:
            **kwargs: Forwarded unchanged to ``BaseModel.__init__``.
        """
        super().__init__(**kwargs)

        # NOTE: attribute names and module order define the state-dict keys,
        # so they must stay stable for previously saved weights to load.
        self.layers = nn.Sequential(
            # projection to a lower dimension
            nn.Conv3d(1536, 1024, 1, padding="same"),
            nn.GELU(),
            SynthesisBlock(1024, 192, 128, d1=32, d2=24),
            SynthesisBlock(128, 64, 32, d1=16, d2=12),
            SynthesisBlock(32, 32, 32, d1=8, d2=4),
            SynthesisBlock(32, 16, 8, d1=2, d2=1),
        )

        # output layer for generating the final segmentation
        self.output_layer = nn.Sequential(
            nn.Conv3d(8, 8, 3, padding="same"),
            nn.GELU(),
            nn.Conv3d(8, 1, 3, padding="same"),
        )

    def forward_volume(self, x: Tensor) -> Tensor:
        """Computes clamped single-channel logits for a channels-first volume.

        Args:
            x (Tensor): Input volume laid out as (B, C, D, H, W) with
                C = 1536, the channel count expected by the first convolution.

        Returns:
            Tensor: Logits with a single channel, clamped to [-5, 5].
        """
        x = self.layers(x)
        x = self.output_layer(x)
        # Bound the logits so a later sigmoid never saturates completely
        # (its output stays within roughly [0.0067, 0.9933]).
        x = torch.clip(x, -5.0, 5.0)
        return x

    def forward(self, batch: BatchedTomogramData) -> Tensor:  # type: ignore
        """Forward pass for the CryoVIT model.

        Args:
            batch (BatchedTomogramData): Batch whose ``tomo_batch`` attribute
                holds the input tensor laid out as (B, D, C, H, W).

        Returns:
            Tensor: Sigmoid probabilities with the channel axis removed,
                laid out as (B, D, H, W).
        """
        x = batch.tomo_batch  # (B, D, C, H, W)
        x = x.permute(0, 2, 1, 3, 4)  # (B, C, D, H, W)
        x = self.forward_volume(x)
        x = x.squeeze(1)  # drop the single output channel -> (B, D, H, W)
        return torch.sigmoid(x)


class SynthesisBlock(nn.Module):
    """Synthesis block for anisotropic upscaling with dilated convolutions.

    The block applies group normalization, then two 3x3x3 convolutions that
    are dilated only along the first spatial (depth) axis, and finally a
    strided transposed convolution that doubles the last two spatial axes.
    The depth axis keeps its size throughout.
    """

    def __init__(self, c1: int, c2: int, c3: int, d1: int, d2: int) -> None:
        """Initializes the Synthesis block for anisotropic upscaling.

        Args:
            c1 (int): Number of channels in the input volume.
            c2 (int): Number of channels in the intermediate tensor.
            c3 (int): Number of channels in the output upscaled volume.
            d1 (int): Depthwise dilation rate for the first 3D Conv layer.
            d2 (int): Depthwise dilation rate for the second 3D Conv layer.
        """
        super().__init__()

        # NOTE: the order of modules fixes the state-dict keys
        # (layers.0, layers.1, layers.3, layers.5); keep it stable.
        self.layers = nn.Sequential(
            nn.GroupNorm(self._num_groups(c1), c1, eps=1e-3),
            nn.Conv3d(c1, c2, 3, padding="same", dilation=(d1, 1, 1)),
            nn.GELU(),
            nn.Conv3d(c2, c2, 3, padding="same", dilation=(d2, 1, 1)),
            nn.GELU(),
            # upscale by 2 along the last two axes only
            nn.ConvTranspose3d(c2, c3, (1, 2, 2), stride=(1, 2, 2)),
            nn.GELU(),
        )

    @staticmethod
    def _num_groups(channels: int) -> int:
        """Picks a GroupNorm group count that evenly divides ``channels``.

        The preferred count is ``max(8, channels // 8)``. ``nn.GroupNorm``
        rejects group counts that do not divide the channel count, so when the
        preferred value is not a divisor the largest divisor below it is used
        instead.

        Args:
            channels (int): Number of channels to be normalized.

        Returns:
            int: Number of groups to pass to ``nn.GroupNorm``.
        """
        preferred = max(8, channels // 8)
        for groups in range(min(preferred, channels), 0, -1):
            if channels % groups == 0:
                return groups
        # Only reached for non-positive channel counts; defer to the
        # downstream layers for validation, as before.
        return preferred

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass for the synthesis block.

        Args:
            x (Tensor): Input volume laid out as (B, c1, D, H, W).

        Returns:
            Tensor: Upscaled volume laid out as (B, c3, D, 2 * H, 2 * W).
        """
        return self.layers(x)