Source code for cryovit.run.eval_model

"""Script for evaluating CryoVIT models based on configuration files."""

import logging
from collections.abc import Iterable
from pathlib import Path

import torch
from hydra import compose, initialize
from hydra.utils import instantiate
from pytorch_lightning import Trainer, seed_everything

from cryovit.config import BaseExperimentConfig
from cryovit.models import create_sam_model_from_weights
from cryovit.utils import load_model

# Module-level side effect: importing this module changes PyTorch's global
# float32 matmul precision setting to "high" (per the PyTorch docs this lets
# float32 matrix multiplications use faster, reduced-precision internal
# kernels where the hardware supports them).
torch.set_float32_matmul_precision("high")

## For Scripts


def run_evaluation(
    test_data: list[Path],
    test_labels: list[Path],
    labels: list[str],
    model_path: Path,
    result_dir: Path,
    visualize: bool = True,
) -> Path:
    """Test a saved model on the given tomograms and return the metrics file path.

    The model and its metadata are read from ``model_path``, the
    ``eval_model`` hydra config is composed with overrides derived from that
    metadata, and ``trainer.test`` is run on a file-based datamodule built
    from the supplied paths.

    Args:
        test_data (list[Path]): Paths to the test tomograms.
        test_labels (list[Path]): Paths to the test labels.
        labels (list[str]): Names of the labels to evaluate.
        model_path (Path): Path to the trained model file.
        result_dir (Path): Directory that receives the evaluation results.
        visualize (bool, optional): When False, the logger registered under
            the config key ``test_pred_writer`` is not instantiated.
            Defaults to True.

    Returns:
        Path: ``<results_dir>/results/<model_name>.csv``. The path is only
        computed here; this function does not itself check that it exists.
    """
    ## Get model information
    model, model_type, model_name, label_key = load_model(model_path)

    ## Setup hydra config
    with initialize(
        version_base="1.2",
        config_path="../configs",
        job_name="cryovit_eval",
    ):
        hydra_overrides = [
            f"name={model_name}",
            f"label_key={label_key}",
            f"model={model_type.value}",
            "additional_keys=[data]",
            "datamodule=file",
        ]
        cfg = compose(config_name="eval_model", overrides=hydra_overrides)

    package_root = Path(__file__).parent.parent
    cfg.paths.model_dir = package_root / "foundation_models"
    cfg.paths.results_dir = result_dir

    # Any input key other than "dino_features" is cleared so that the
    # available data is looked up instead.
    if cfg.model.input_key != "dino_features":
        cfg.model.input_key = None

    ## Setup dataset
    make_dataset = instantiate(cfg.datamodule.dataset)
    make_dataloader = instantiate(cfg.datamodule.dataloader)
    build_datamodule = instantiate(cfg.datamodule, _convert_="all")
    datamodule = build_datamodule(
        data_paths=test_data,
        data_labels=test_labels,
        labels=labels,
        val_path=None,
        val_labels=None,
        dataloader_fn=make_dataloader,
        dataset_fn=make_dataset,
    )
    logging.info("Setup dataset.")

    ## Setup evaluation
    callbacks = []
    for callback_cfg in cfg.callbacks.values():
        callbacks.append(instantiate(callback_cfg))

    # Skip the prediction writer when visualization is turned off.
    loggers = []
    for logger_name, logger_cfg in cfg.logger.items():
        if logger_name == "test_pred_writer" and not visualize:
            continue
        loggers.append(instantiate(logger_cfg))

    trainer = instantiate(cfg.trainer, callbacks=callbacks, logger=loggers)

    logging.info("Starting testing.")
    trainer.test(model, datamodule=datamodule)

    # Location where the metrics for this model are expected.
    return cfg.paths.results_dir / "results" / f"{model_name}.csv"
## For Experiments


def _sample_to_name(sample):
    """Turn a sample specification into a single directory-name component.

    A non-string iterable is sorted and joined with underscores; any other
    value (a single name, or None) is returned unchanged.
    """
    if not isinstance(sample, str) and isinstance(sample, Iterable):
        return "_".join(sorted(sample))
    return sample


def setup_exp_dir(cfg: BaseExperimentConfig) -> BaseExperimentConfig:
    """Resolve the experiment directory and default checkpoint path for evaluation.

    The directory is ``exp_dir / name / sample``, followed by
    ``split_<split_id>`` when a split id is set, followed by the test sample
    when the datamodule target is ``FractionalSampleDataModule`` and a test
    sample is set. The results directory is created if missing; the
    experiment directory must already exist.

    Args:
        cfg (BaseExperimentConfig): Experiment configuration. Modified in
            place: the entries under ``cfg.paths`` are converted to ``Path``,
            ``cfg.paths.exp_dir`` is replaced by the resolved directory, and
            ``cfg.ckpt_path`` defaults to ``<exp_dir>/weights.pt`` when unset.

    Returns:
        BaseExperimentConfig: The same (mutated) configuration object.

    Raises:
        AssertionError: If the resolved experiment directory does not exist.
    """
    # Convert paths to Paths
    paths = cfg.paths
    paths.model_dir = Path(paths.model_dir)
    paths.data_dir = Path(paths.data_dir)
    paths.exp_dir = Path(paths.exp_dir)
    paths.results_dir = Path(paths.results_dir)

    sample = _sample_to_name(cfg.datamodule.sample)
    test_sample = _sample_to_name(cfg.datamodule.test_sample)

    new_exp_dir = cfg.paths.exp_dir / cfg.name / sample
    if cfg.datamodule.split_id is not None:
        new_exp_dir = new_exp_dir / f"split_{cfg.datamodule.split_id}"

    # Append test sample only for fractional sample experiments
    # (test sample is separated for validation)
    is_fractional = (
        cfg.datamodule._target_
        == "cryovit.datamodules.FractionalSampleDataModule"
    )
    if is_fractional and test_sample is not None:
        new_exp_dir = new_exp_dir / f"{test_sample}"

    cfg.paths.results_dir.mkdir(parents=True, exist_ok=True)

    # NOTE: kept as an assert so the raised exception type stays
    # AssertionError; be aware that asserts are skipped under ``python -O``.
    assert (
        new_exp_dir.exists()
    ), f"Experiment directory {new_exp_dir} does not exist. Run training first."

    cfg.paths.exp_dir = new_exp_dir
    if cfg.ckpt_path is None:
        cfg.ckpt_path = new_exp_dir / "weights.pt"
    return cfg


def run_trainer(cfg: BaseExperimentConfig) -> None:
    """Set up and run model evaluation from an experiment configuration.

    Seeds the RNGs, resolves the experiment directory, builds the datamodule,
    trainer and model from the config, loads the checkpoint weights and calls
    ``trainer.test``.

    Args:
        cfg (BaseExperimentConfig): Configuration object containing all
            settings for the evaluation process.

    Raises:
        AssertionError: If the experiment directory or the checkpoint file
            does not exist.
        ValueError: If the checkpoint suffix is neither ``.pt`` nor ``.ckpt``.
    """
    seed_everything(cfg.random_seed, workers=True)

    # Setup experiment directories
    cfg = setup_exp_dir(cfg)
    assert (
        cfg.ckpt_path is not None and cfg.ckpt_path.exists()
    ), f"{cfg.paths.exp_dir} does not contain a checkpoint."

    # Setup dataset
    dataset_fn = instantiate(cfg.datamodule.dataset)
    dataloader_fn = instantiate(cfg.datamodule.dataloader)
    split_file = cfg.paths.data_dir / cfg.paths.csv_name / cfg.paths.split_name
    datamodule = instantiate(cfg.datamodule, _convert_="all")(
        split_file=split_file,
        dataloader_fn=dataloader_fn,
        dataset_fn=dataset_fn,
    )
    logging.info("Setup dataset.")

    # Setup evaluation
    callbacks = [instantiate(cb_cfg) for cb_cfg in cfg.callbacks.values()]
    loggers = [instantiate(lg_cfg) for lg_cfg in cfg.logger.values()]
    trainer: Trainer = instantiate(
        cfg.trainer, callbacks=callbacks, logger=loggers
    )
    logging.info("Setup trainer.")

    if cfg.model._target_ == "cryovit.models.sam2.SAM2":
        # Load SAM2 pre-trained models
        model = create_sam_model_from_weights(
            cfg.model, cfg.paths.model_dir / cfg.paths.sam_name
        )
    else:
        model = instantiate(cfg.model)

    # Load model weights
    ckpt_suffix = cfg.ckpt_path.suffix
    if ckpt_suffix == ".pt":
        # Load onto CPU first so weights saved on a GPU can be restored on a
        # machine without one; load_state_dict copies the values into the
        # model's existing parameters regardless of where they were loaded.
        state_dict = torch.load(cfg.ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict)
    elif ckpt_suffix == ".ckpt":
        # load_from_checkpoint is a classmethod; recent Lightning releases
        # refuse to run it through an instance, so call it on the class.
        model = type(model).load_from_checkpoint(cfg.ckpt_path)
    else:
        raise ValueError(
            f"Unsupported checkpoint format: {cfg.ckpt_path.suffix}. Use .pt or .ckpt files."
        )
    logging.info("Setup model.")

    logging.info("Starting testing.")
    trainer.test(model, datamodule=datamodule)