Source code for cryovit.run.dino_features

"""Functions to load data and run DINOv2 feature extraction."""

import logging
from pathlib import Path

import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from hydra import compose, initialize
from hydra.utils import instantiate
from numpy.typing import NDArray
from rich.progress import track

from cryovit.config import (
    DEFAULT_WINDOW_SIZE,
    DINO_PATCH_SIZE,
    BaseDataModule,
    DinoFeaturesConfig,
    samples,
    tomogram_exts,
)
from cryovit.types import FileData
from cryovit.visualization.dino_pca import export_pca

torch.set_float32_matmul_precision("high")  # ensures tensor cores are used
dino_model = (
    "facebookresearch/dinov2",
    "dinov2_vitg14_reg",
)  # the giant variant of DINOv2


def _folded_to_patch(
    x: torch.Tensor, C: int, window_size: tuple[int, int]
) -> tuple[torch.Tensor, int, int]:
    """Convert unfolded tensor to patch features tensor shape."""

    B, _, L = x.shape  # B, C * window_size * window_size, L
    x = x.permute(0, 2, 1).contiguous()  # B, L, C * window_size * window_size
    x = x.reshape(
        -1, C, window_size[0], window_size[1]
    )  # B * L, C, window_size, window_size
    return x, B, L


def _patch_to_folded(x: torch.Tensor, B: int, L: int) -> torch.Tensor:
    """Convert patch features tensor to folded tensor shape."""

    x = x.reshape(
        B, L, x.shape[1], x.shape[2]
    )  # B, L, window_size * window_size, C2
    x = x.permute(
        0, 3, 2, 1
    ).contiguous()  # B, C2, window_size * window_size, L
    x = x.reshape(B, -1, L)  # B, C2 * window_size * window_size, L
    return x


@torch.inference_mode()
def _dino_features(
    data: torch.Tensor,
    model: torch.nn.Module,
    batch_size: int,
    window_size: tuple[int, int] | int | None = DEFAULT_WINDOW_SIZE,
) -> NDArray[np.float16]:
    """Extract patch features from a tomogram using a DINOv2 model. Uses windows to fit everything in memory.

    Args:
        data (torch.Tensor): The input data tensors containing the tomogram's data.
        model (nn.Module): The pre-loaded DINOv2 model used for feature extraction.
        batch_size (int): The maximum number of 2D slices of a tomograms processed in each batch.
        window_size (Optional[tuple[int, int] | int]): The size of the windows to use for feature extraction.

    Returns:
        NDArray[np.float16]: A numpy array containing the extracted features in reduced precision.
    """

    if window_size is None:
        window_size = (DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE)
    elif isinstance(window_size, int):
        window_size = (window_size, window_size)
    C, H, W = data.shape[1:]
    ph, pw = (
        H // DINO_PATCH_SIZE,
        W // DINO_PATCH_SIZE,
    )  # the number of patches extracted per slice
    window_size = (min(window_size[0], H), min(window_size[1], W))
    stride = tuple(
        (ws // 2) // DINO_PATCH_SIZE * DINO_PATCH_SIZE for ws in window_size
    )  # 50% overlap, multiple of 14
    fold = nn.Fold(
        output_size=(ph, pw),
        kernel_size=tuple(ws // DINO_PATCH_SIZE for ws in window_size),
        stride=tuple(s // DINO_PATCH_SIZE for s in stride),
    )
    unfold = nn.Unfold(
        kernel_size=window_size,
        stride=stride,
    )
    # divisor tensor for averaging overlapping patches
    divisor = torch.ones(1, 1, H, W).cuda()
    divisor = unfold(divisor)
    divisor, B, L = _folded_to_patch(divisor, 1, window_size)
    divisor = F.max_pool2d(
        divisor, DINO_PATCH_SIZE, stride=DINO_PATCH_SIZE
    )  # simulate patch extraction
    divisor = divisor.flatten(2).transpose(1, 2)  # B, L, 1
    divisor = _patch_to_folded(divisor, B, L)
    divisor = fold(divisor)

    # Fold data into overlapping windows
    data = unfold(data)  # B, C * window_size * window_size, L
    data, B, L = _folded_to_patch(
        data, C, window_size
    )  # B * L, C, window_size, window_size
    # Calculate features in batches
    all_features = []
    num_batches = (len(data) + batch_size - 1) // batch_size
    for i, batch_idx in enumerate(range(0, len(data), batch_size)):
        # Check for overflows
        if batch_idx + batch_size > len(data):
            batch_size = len(data) - batch_idx
        logging.debug(
            "Processing batch %d/%d: %d -> %d",
            i + 1,
            num_batches,
            batch_idx,
            (batch_idx + batch_size),
        )
        vec = data[batch_idx : batch_idx + batch_size].cuda().float()
        # Convert to RGB by repeating channels
        if vec.shape[1] == 1:
            vec = vec.repeat(1, 3, 1, 1)
        # Extract features for each patch
        patch_features = model.forward_features(vec)[  # type: ignore
            "x_norm_patchtokens"
        ]  # B', ph * pw, C2
        all_features.append(patch_features.half())
    features = torch.cat(all_features, dim=0)  # B * L, ph * pw, C2
    features = _patch_to_folded(features, B, L)  # B, C2 * ph * pw, L
    features = fold(features) / divisor  # B, C2, ph, pw
    features = features.permute([1, 0, 2, 3]).contiguous()  # C2, D, W, H
    features = features.cpu().numpy()
    return features


def _save_data(
    data: dict[str, NDArray[np.uint8]],
    features: NDArray[np.float16],
    tomo_name: str,
    dst_dir: Path,
) -> None:
    """Save extracted features to a specified directory as an .hdf file, copying the source tomogram data."""

    dst_dir.mkdir(parents=True, exist_ok=True)

    with h5py.File(dst_dir / tomo_name, "w") as fh:
        for key in data:
            if key != "data":
                if "labels" not in fh:
                    fh.create_group("labels")
                fh["labels"].create_dataset(key, data=data[key], shape=data[key].shape, dtype=data[key].dtype, compression="gzip")  # type: ignore
            else:
                fh.create_dataset(
                    "data",
                    data=data[key],
                    shape=data[key].shape,
                    dtype=data[key].dtype,
                    compression="gzip",
                )
        fh.create_dataset(
            "dino_features",
            data=features,
            shape=features.shape,
            dtype=features.dtype,
        )


def _process_sample(
    src_dir: Path,
    dst_dir: Path,
    csv_dir: Path,
    model: torch.nn.Module,
    sample: str,
    datamodule: BaseDataModule,
    batch_size: int,
    image_dir: Path | None,
):
    """Process all tomograms in a single sample by extracting and saving their DINOv2 features."""

    tomo_dir = src_dir / sample
    result_dir = dst_dir / sample
    csv_file = csv_dir / f"{sample}.csv"
    # If no .csv file, use all tomograms in the dataset
    if not csv_file.exists():
        records = [
            f.name for f in tomo_dir.glob("*") if f.suffix in tomogram_exts
        ]
    else:
        records = pd.read_csv(csv_file)["tomo_name"].to_list()

    dataset = instantiate(datamodule.dataset, data_root=tomo_dir)(
        records=records
    )
    dataloader = instantiate(datamodule.dataloader)(dataset=dataset)

    for i, x in track(
        enumerate(dataloader),
        description=f"[green]Computing features for {sample}",
        total=len(dataloader),
    ):
        features = _dino_features(x, model, batch_size)

        data = {}
        with h5py.File(tomo_dir / records[i]) as fh:
            for key in fh:
                data[key] = fh[key][()]  # type: ignore

        _save_data(data, features, records[i], result_dir)
        # Save PCA calculation of features
        if image_dir is not None:
            export_pca(data["data"], features.astype(np.float32), records[i][:-4], image_dir / sample)  # type: ignore


## For Scripts


[docs] def run_dino( train_data: list[Path], result_dir: Path, batch_size: int, window_size: int | None = DEFAULT_WINDOW_SIZE, visualize: bool = False, ) -> None: """Run DINO feature extraction on the specified training data, and saves the results as .hdf files. The saved result file will contain `data`, `dino_features`, and any labels present in the source tomogram in the `labels/` group. Args: train_data (list[Path]): List of paths to the training tomograms. result_dir (Path): Directory where the results will be saved. batch_size (int): Number of samples to process in each batch. window_size (Optional[int], optional): Size of the sliding window for feature extraction. If None, uses the default size. visualize (bool, optional): Whether to visualize the extracted features. Defaults to False. """ ## Setup hydra config with initialize( version_base="1.2", config_path="../configs", job_name="cryovit_features", ): cfg = compose( config_name="dino_features", overrides=[ f"batch_size={batch_size}", "sample=null", "export_features=False", "datamodule/dataset=file", ], ) cfg.paths.model_dir = Path(__file__).parent.parent / "foundation_models" ## Load DINOv2 model torch.hub.set_dir(cfg.dino_dir) model = torch.hub.load(*dino_model, verbose=False).cuda() # type: ignore model.eval() ## Setup dataset assert ( len(train_data) > 0 ), "No valid tomogram files found in the specified training data path." train_file_datas = [FileData(tomo_path=f) for f in train_data] dataset = instantiate( cfg.datamodule.dataset, input_key=None, label_key=None )(train_file_datas, for_dino=True) dataloader = instantiate(cfg.datamodule.dataloader)(dataset=dataset) result_list = [result_dir / f"{f.stem}.hdf" for f in train_data] ## Iterate through dataloader and extract features try: for i, x in track( enumerate(dataloader), description="[green]Computing DINO features for training data", total=len(dataloader), ): features = _dino_features( x.data, model, cfg.batch_size, window_size=window_size ) result_path = result_list[i].with_suffix(".hdf") result_path.parent.mkdir(parents=True, exist_ok=True) result_dir, tomo_name = result_path.parent, result_path.name _save_data(x.aux_data, features, tomo_name, result_dir) if visualize: export_pca( x.aux_data["data"], features, tomo_name[:-4], result_dir.parent / "dino_images" / tomo_name[:-4], ) except torch.OutOfMemoryError: print( f"Ran out of GPU memory during DINO feature extraction. Try reducing the batch size or window size. Current batch size is {cfg.batch_size} and window size is {window_size}." ) return
## For Experiments def run_trainer(cfg: DinoFeaturesConfig) -> None: """Sets up and executes the DINO feature computation using the specified configuration. Args: cfg (DinoFeaturesConfig): Configuration object containing all settings for the DINO feature computation. """ # Convert paths to Paths cfg.paths.model_dir = Path(cfg.paths.model_dir) cfg.paths.data_dir = Path(cfg.paths.data_dir) cfg.paths.exp_dir = Path(cfg.paths.exp_dir) src_dir = cfg.paths.data_dir / cfg.paths.tomo_name dst_dir = cfg.paths.data_dir / cfg.paths.feature_name csv_dir = cfg.paths.data_dir / cfg.paths.csv_name image_dir = cfg.paths.exp_dir / "dino_images" sample_names = ( [cfg.sample.name] if cfg.sample is not None else [s for s in samples if (src_dir / s).exists()] ) torch.hub.set_dir(cfg.dino_dir) model = torch.hub.load(*dino_model, verbose=False).cuda() # type: ignore model.eval() for sample_name in sample_names: _process_sample( src_dir, dst_dir, csv_dir, model, sample_name, cfg.datamodule, # type: ignore cfg.batch_size, image_dir if cfg.export_features else None, )