Source code for cryovit.models.base_model

"""Base Lightning Module class for 3D Tomogram Segmentation."""

from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import deepcopy
from typing import Literal

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import grad_norm
from torch import Tensor, nn
from torch.optim import Optimizer

from cryovit.types import BatchedModelResult, BatchedTomogramData


[docs] class BaseModel(LightningModule, ABC): """Base model with configurable loss functions and metrics."""
[docs] def __init__( self, input_key: str, lr: float, weight_decay: float, losses: dict[str, Callable], metrics: dict[str, Callable], name: str = "BaseModel", custom_kwargs: dict | None = None, **kwargs, ) -> None: """Initializes the BaseModel with specified learning rate, weight decay, loss functions, and metrics. Args: input_key (str): Key to access input data in the batch. lr (float): Learning rate for the optimizer. weight_decay (float): Weight decay factor for AdamW optimizer. losses (dict[str, Callable]): Dictionary of loss functions for training, validation, and testing. metrics (dict[str, Callable]): Dictionary of metric functions for training, validation, and testing. name (str): Name of the model. custom_kwargs (Optional[dict[str, Any]]): Additional custom keyword arguments to set as attributes. """ super().__init__() self.name = name self.input_key = input_key self.lr = lr self.weight_decay = weight_decay if custom_kwargs is not None: for key, value in custom_kwargs.items(): setattr(self, key, value) self.configure_losses(losses) self.configure_metrics(metrics) self.save_hyperparameters()
[docs] def configure_optimizers(self) -> Optimizer: return torch.optim.AdamW( self.parameters(), lr=self.lr, weight_decay=self.weight_decay, )
def configure_losses(self, losses: dict[str, Callable]) -> None: self.loss_fns = losses def configure_metrics(self, metrics: dict[str, Callable]) -> None: self.metric_fns = nn.ModuleDict( { "TRAIN": nn.ModuleDict({m: deepcopy(m_fn) for m, m_fn in metrics.items()}), # type: ignore "VAL": nn.ModuleDict({m: deepcopy(m_fn) for m, m_fn in metrics.items()}), # type: ignore "TEST": nn.ModuleDict({m: deepcopy(m_fn) for m, m_fn in metrics.items()}), # type: ignore } )
[docs] def on_before_optimizer_step(self, optimizer: Optimizer) -> None: """Logs gradient norms just before the optimizer updates weights.""" norms = grad_norm(self, norm_type=2) self.log_dict(norms, on_step=True)
def _masked_predict( self, batch: BatchedTomogramData, use_mito_mask: bool = False # type: ignore ) -> dict[str, Tensor]: """Performs prediction while applying a mask to the inputs and labels based on the label value.""" y_true = batch.labels # (B, D, H, W) y_pred_full = self(batch) # (B, D, H, W) mask = (y_true > -1.0).detach() if use_mito_mask: assert ( batch.aux_data is not None and "labels/mito" in batch.aux_data ), "Batch aux_data must contain 'labels/mito' key for mito masking." # assumes eval code with batch size of 1 mito_mask = torch.from_numpy(batch.aux_data["labels/mito"][0]) > 0 mito_mask = mito_mask.to(dtype=mask.dtype, device=mask.device) mask = mask & mito_mask # Combine masks y_pred = torch.masked_select(y_pred_full, mask).view(-1, 1) y_true = torch.masked_select(y_true, mask).view(-1, 1) return {"preds": y_pred, "labels": y_true, "preds_full": y_pred_full} def compute_losses( self, y_pred: Tensor, y_true: Tensor ) -> dict[str, Tensor]: losses = {k: v(y_pred, y_true) for k, v in self.loss_fns.items()} losses["total"] = sum(losses.values()) return losses
[docs] def log_stats( self, losses: dict[str, Tensor], prefix: Literal["train", "val", "test"], batch_size: int, ) -> None: """Logs computed loss and metric statistics for each training or validation step.""" # Log losses loss_log_dict = {f"{prefix}/loss/{k}": v for k, v in losses.items()} on_step = prefix == "train" self.log_dict( loss_log_dict, prog_bar=True, on_epoch=not on_step, on_step=on_step, batch_size=batch_size, ) # Log metrics metric_log_dict = {} for m, m_fn in self.metric_fns[prefix.upper()].items(): # type: ignore metric_log_dict[f"{prefix}/metric/{m}"] = m_fn self.log_dict( metric_log_dict, prog_bar=True, on_epoch=True, on_step=False, batch_size=batch_size, )
def _do_step(self, batch: BatchedTomogramData, batch_idx: int, prefix: Literal["train", "val", "test"]) -> Tensor: # type: ignore """Processes a single batch of data, computes the loss and updates metrics.""" out = self._masked_predict(batch) y_pred, y_true = out["preds"], out["labels"] losses = self.compute_losses(y_pred, y_true) for _, m_fn in self.metric_fns[prefix.upper()].items(): # type: ignore m_fn(y_pred, y_true) self.log_stats(losses, prefix, batch.num_tomos) return losses["total"]
[docs] def training_step(self, batch: BatchedTomogramData, batch_idx: int) -> Tensor: # type: ignore """Processes one batch during training, returning the total loss.""" return self._do_step(batch, batch_idx, "train") # type: ignore
[docs] def validation_step(self, batch: BatchedTomogramData, batch_idx: int) -> Tensor: # type: ignore """Processes one batch during validation, returning the total loss.""" return self._do_step(batch, batch_idx, "val") # type: ignore
[docs] def test_step(self, batch: BatchedTomogramData, batch_idx: int) -> BatchedModelResult: # type: ignore """Processes one batch during testing, captures predictions, and computes losses and metrics. Args: batch (BatchedTomogramData): The batch of data being processed. batch_idx (int): Index of the batch. Returns: BatchedModelResult: Contains test results and metrics, as well as file metadata for this batch. """ assert ( batch.aux_data is not None and "data" in batch.aux_data ), "Batch aux_data must contain 'data' key for testing." input_data = batch.aux_data["data"] # whether to use mito labels to mask granule predictions in evaluation use_mito_mask = ( "labels/mito" in batch.aux_data and batch.aux_data["labels/mito"] if batch.aux_data is not None else False ) out_dict = self._masked_predict(batch, use_mito_mask=use_mito_mask) y_pred, y_true, y_pred_full = ( out_dict["preds"], out_dict["labels"], out_dict["preds_full"], ) aux_data = {} for key in out_dict: if key not in ["preds", "labels", "preds_full"] and key.endswith( "full" ): aux_data_full = out_dict[key] aux_data[key] = [ t_aux_data.cpu().numpy() for t_aux_data in aux_data_full ] samples, tomo_names = batch.metadata.identifiers split_id = batch.metadata.split_id labels = [t_labels.cpu().numpy() for t_labels in batch.labels] preds = [t_preds.cpu().numpy() for t_preds in y_pred_full] if split_id is not None: split_id = [s_id.item() for s_id in split_id] losses = { k: v.item() for k, v in self.compute_losses(y_pred, y_true).items() } metrics = {} for m, m_fn in self.metric_fns["TEST"].items(): # type: ignore score = m_fn(y_pred, y_true) metrics[m] = score.item() m_fn.reset() return BatchedModelResult( num_tomos=batch.num_tomos, samples=samples, tomo_names=tomo_names, split_id=split_id, data=input_data, label=labels, preds=preds, # type: ignore losses=losses, metrics=metrics, aux_data=aux_data if aux_data else None, )
[docs] def predict_step(self, batch: BatchedTomogramData, batch_idx: int) -> BatchedModelResult: # type: ignore """Processes one batch during prediction, capturing model outputs along with file metadata. Args: batch (BatchedTomogramData): The batch of data being processed. batch_idx (int): Index of the batch. Returns: BatchedModelResult: Contains test results and metrics, as well as file metadata for this batch. """ assert ( batch.aux_data is not None and "data" in batch.aux_data ), "Batch aux_data must contain 'data' key for prediction." input_data = batch.aux_data["data"] preds = self(batch) labels = [t_labels.cpu().numpy() for t_labels in batch.labels] samples, tomo_names = batch.metadata.identifiers return BatchedModelResult( num_tomos=batch.num_tomos, samples=samples, tomo_names=tomo_names, split_id=None, data=input_data, label=labels, preds=[t_preds.cpu().numpy() for t_preds in preds], losses={}, metrics={}, aux_data=None, )
[docs] @abstractmethod def forward(self): """Should be implemented in subclass.""" raise NotImplementedError( "The forward method must be implemented by subclass." )