# Source code for cryovit.run.infer_model

"""Script for evaluating CryoVIT models based on configuration files."""

import logging
from pathlib import Path

import torch
from hydra import compose, initialize
from hydra.utils import instantiate

from cryovit.models.callbacks import PredictionWriter
from cryovit.utils import load_model

# Allow PyTorch to use faster, reduced-precision float32 matmul kernels where
# the hardware supports them (per the torch.set_float32_matmul_precision docs).
# This is global torch state, set as a side effect of importing this module.
torch.set_float32_matmul_precision("high")

## For Scripts


def run_inference(
    data_files: list[Path],
    model_path: Path,
    result_dir: Path,
    threshold: float = 0.5,
) -> list[Path]:
    """Run inference on the specified data files and save the results.

    Loads a trained model, composes the ``infer_model`` Hydra config for it
    (with the ``file`` datamodule), and runs ``trainer.predict`` with a
    ``PredictionWriter`` callback that targets ``result_dir``.

    Args:
        data_files (list[Path]): List of paths to the input data files.
        model_path (Path): Path to the trained model file.
        result_dir (Path): Directory where the inference results will be saved.
        threshold (float, optional): Threshold for binary classification.
            Defaults to 0.5.

    Returns:
        list[Path]: List of paths to the saved result files, as recorded by
            the prediction writer.

    Raises:
        RuntimeError: If ``load_model`` returns no model for ``model_path``.
    """
    # Named logger: the root-level ``logging.info`` helper would implicitly
    # call ``logging.basicConfig()`` and reconfigure the caller's root logger.
    log = logging.getLogger(__name__)

    # Get model information
    model, model_type, model_name, label_key = load_model(model_path)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if model is None:
        raise RuntimeError(f"Loaded model is None: {model_path}")

    # Setup hydra config (config_path is resolved relative to this file).
    with initialize(
        version_base="1.2",
        config_path="../configs",
        job_name="infer_model",
    ):
        cfg = compose(
            config_name="infer_model",
            overrides=[
                f"name={model_name}",
                f"label_key={label_key}",
                f"model={model_type.value}",
                "datamodule=file",
            ],
        )
    cfg.paths.model_dir = Path(__file__).parent.parent / "foundation_models"
    cfg.paths.results_dir = result_dir

    # Keep the input key only for DINO-feature models; otherwise clear it so
    # the available data is found instead.
    if cfg.model.input_key != "dino_features":
        cfg.model.input_key = None

    # Setup dataset. Instantiating cfg.datamodule yields a callable
    # (presumably a partial config - confirm) that builds the datamodule.
    dataset_fn = instantiate(cfg.datamodule.dataset)
    dataloader_fn = instantiate(cfg.datamodule.dataloader)
    datamodule = instantiate(cfg.datamodule, _convert_="all")(
        data_paths=data_files,
        val_paths=None,
        dataloader_fn=dataloader_fn,
        dataset_fn=dataset_fn,
    )
    log.info("Setup dataset.")

    # Setup prediction: configured callbacks plus the writer, whose
    # ``result_paths`` is read back once prediction finishes.
    pred_writer = PredictionWriter(
        results_dir=result_dir, label_key=label_key, threshold=threshold
    )
    callbacks = [instantiate(cb_cfg) for cb_cfg in cfg.callbacks.values()]
    callbacks.append(pred_writer)
    # Empty logger list: no experiment loggers are attached for inference.
    trainer = instantiate(cfg.trainer, callbacks=callbacks, logger=[])

    log.info("Starting prediction.")
    trainer.predict(model, datamodule=datamodule)
    return pred_writer.result_paths