Source code for cryovit.datasets.vit_dataset

"""Dataset class for loading tomograms for DINOv2 models."""

import logging
from pathlib import Path

import h5py
import numpy as np
import torch
import torch.nn.functional as F
from numpy.typing import NDArray
from torch.utils.data import Dataset
from torchvision.transforms import Normalize

from cryovit.config import DINO_PATCH_SIZE

# Per-channel (R, G, B) mean and standard deviation, named as the ImageNet
# defaults; used to build the torchvision ``Normalize`` transform in VITDataset.
IMAGENET_DEFAULT_MEAN: tuple[float, float, float] = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD: tuple[float, float, float] = (0.229, 0.224, 0.225)


class VITDataset(Dataset):
    """Dataset class for Vision Transformer models, loading and processing tomograms.

    Each item is one whole tomogram, returned as a stack of 3-channel slices
    that have been ImageNet-normalized and rescaled so that height and width
    are multiples of ``DINO_PATCH_SIZE``.
    """

    def __init__(self, data_root: Path, records: list[str]) -> None:
        """Initializes a dataset object to load tomograms, applying normalization
        and resizing for DINOv2 models.

        Args:
            data_root (Path): Root directory where tomogram files are stored.
                Non-``Path`` values (e.g. ``str``) are converted to ``Path``.
            records (list[str]): Paths to tomogram files, relative to
                ``data_root``.
        """
        self.root = data_root if isinstance(data_root, Path) else Path(data_root)
        self.records = records
        # Per-channel normalization applied to every slice in ``_transform``.
        self.transform = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        # Log the padding warning only once per dataset instance.
        self._printed_resize_warning = False

    def __len__(self) -> int:
        """Returns the number of tomograms in the dataset."""
        return len(self.records)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Retrieves a preprocessed tomogram tensor from the dataset by index.

        Args:
            idx (int): Index of the tomogram to retrieve.

        Returns:
            torch.Tensor: A tensor representing the normalized and resized tomogram.

        Raises:
            IndexError: If the index is out of the dataset's bounds.
        """
        if idx >= len(self):
            raise IndexError
        record = self.records[idx]
        data = self._load_tomogram(record)
        return self._transform(data)

    def _load_tomogram(self, record: str) -> NDArray[np.float32]:
        """Loads a tomogram from disk, assuming it is stored as an .hdf file in a `data` key.

        ``uint8`` volumes are rescaled from ``[0, 255]`` to ``[0, 1]``. Any other
        dtype is cast to ``float32`` without rescaling.
        NOTE(review): non-uint8 data is presumably already in ``[0, 1]`` — confirm.

        Args:
            record (str): The file path to the tomogram relative to the root directory.

        Returns:
            NDArray[np.float32]: The loaded tomogram as a float32 numpy array.
        """
        tomo_path = self.root / record
        # Open read-only explicitly; this dataset never writes to the file.
        with h5py.File(tomo_path, "r") as fh:
            data: np.ndarray = fh["data"][()]  # type: ignore
        if data.dtype == np.uint8:
            return data.astype(np.float32) / 255.0
        return data.astype(np.float32, copy=False)

    def _transform(self, data: NDArray[np.float32]) -> torch.Tensor:
        """Applies normalization and resizing transformations to the tomogram.

        Height and width are edge-padded up to the next multiple of 16. Each
        slice is replicated to 3 channels and normalized with ``self.transform``.
        The result is bicubically rescaled by ``DINO_PATCH_SIZE / 16`` so that
        both spatial sizes become multiples of ``DINO_PATCH_SIZE``.

        Args:
            data (NDArray[np.float32]): The loaded tomogram, shape ``(D, H, W)``.

        Returns:
            torch.Tensor: The transformed data, shape ``(D, 3, H', W')``.
        """
        scale = (DINO_PATCH_SIZE / 16, DINO_PATCH_SIZE / 16)
        _, h, w = data.shape

        # Pad height and width to multiples of 16 (exact integer ceiling).
        H = (h + 15) // 16 * 16
        W = (w + 15) // 16 * 16
        if h != H or w != W:
            if not self._printed_resize_warning:
                logging.warning("Resizing tomogram from %s to %s", (h, w), (H, W))
                self._printed_resize_warning = True
            data = np.pad(data, ((0, 0), (0, H - h), (0, W - w)), mode="edge")

        # D, C, H, W: slices act as the batch axis; grayscale -> 3 channels.
        np_data = np.repeat(np.expand_dims(data, axis=1), 3, axis=1)
        torch_data = torch.from_numpy(np_data).float()  # values expected in [0, 1]
        # Apply the ImageNet normalization built in __init__ (previously skipped).
        torch_data = self.transform(torch_data)
        return F.interpolate(torch_data, scale_factor=scale, mode="bicubic")