Source code for cryovit.datasets.tomo_dataset

"""Dataset class for loading tomograms for CryoViT experiments."""

from pathlib import Path
from typing import Any

import h5py
import numpy as np
import pandas as pd
from torch.utils.data import Dataset

from cryovit.types import TomogramData


[docs] class TomoDataset(Dataset): """A dataset class for handling and preprocessing tomographic data for CryoVIT models."""
[docs] def __init__( self, records: pd.DataFrame, input_key: str, label_key: str, split_key: str, data_root: Path, aux_keys: list[str] | None = None, train: bool = False, ) -> None: """Initializes a dataset object to load tomograms for model training, applying optional training crops. Args: records (pd.DataFrame): A DataFrame containing records of tomograms. input_key (str): The key in the HDF5 file to access input features. label_key (str): The key in the HDF5 file to access labels. split_key (str): The key in the DataFrame to access the split identifier. data_root (Path): The root directory where the tomograms are stored. aux_keys (Optional[List[str]]): Optional additional keys for auxiliary data to load from the HDF5 files. train (bool): Flag to determine if the dataset is for training (enables transformations). """ if aux_keys is None: aux_keys = [] self.records = records self.input_key = input_key self.label_key = label_key self.split_key = split_key self.aux_keys = aux_keys self.data_root = ( data_root if isinstance(data_root, Path) else Path(data_root) ) self.train = train
def __len__(self) -> int: """Returns the total number of tomograms in the dataset.""" return len(self.records) def __getitem__(self, idx: int) -> TomogramData: # type: ignore """Retrieves a single item from the dataset. Args: idx (int): The index of the item. Returns: TomogramData: A dataclass containing the loaded data, labels, and metadata. Raises: IndexError: If index is out of the range of the dataset. """ if idx >= len(self): raise IndexError record = self.records.iloc[idx] assert isinstance(record, pd.Series) data = self._load_tomogram(record) if self.train: self._random_crop(data) return TomogramData( sample=record["sample"], tomo_name=record["tomo_name"], split_id=data.get("split_id", None), data=data["input"], label=data["label"], aux_data={key: data[key] for key in self.aux_keys if key in data}, ) # type: ignore def _load_tomogram(self, record: pd.Series) -> dict[str, Any]: """Loads a single tomogram based on the record information. Args: record (pd.Series): A series containing the sample and tomogram names. Returns: data (dict[str, Any]): A dictionary with input data, label, and any auxiliary data. """ tomo_path = self.data_root / record["sample"] / record["tomo_name"] # sample, tomo_name, split_id, input, label data_dict = { "sample": record["sample"], "tomo_name": record["tomo_name"], } if self.split_key in record.index: data_dict["split_id"] = record[self.split_key] with h5py.File(tomo_path) as fh: data: np.ndarray = fh[self.input_key][()] # type: ignore if data.dtype == np.uint8: data = data.astype(np.float32) / 255.0 if len(data.shape) == 3: data = data[np.newaxis, ...] # add channel dimension data_dict["input"] = data data_dict["label"] = fh["labels"][self.label_key][()] # type: ignore data_dict |= {key: fh[key][()] for key in self.aux_keys if key in fh} # type: ignore return data_dict def _random_crop(self, data: dict[str, Any]) -> None: """Applies a random crop to the input data in the record dictionary. Args: record (dict[str, Any]): The record dictionary containing 'input' and 'label' data. """ max_depth = 128 side = 32 if self.input_key == "dino_features" else 512 d, h, w = data["input"].shape[-3:] x, y, z = min(d, max_depth), side, side if (d, h, w) == (x, y, z): return # nothing to be done delta_d = d - x + 1 delta_h = h - y + 1 delta_w = w - z + 1 di = np.random.choice(delta_d) if delta_d > 0 else 0 hi = np.random.choice(delta_h) if delta_h > 0 else 0 wi = np.random.choice(delta_w) if delta_w > 0 else 0 data["input"] = data["input"][ ..., di : di + x, hi : hi + y, wi : wi + z ] if self.input_key == "dino_features": hi, wi, y, z = 16 * np.array([hi, wi, y, z]) data["label"] = data["label"][di : di + x, hi : hi + y, wi : wi + z]