# Source code for cryovit.run.train_model

"""Script for setting up and training CryoVIT models based on configuration files."""

import logging
from collections.abc import Iterable
from pathlib import Path

import torch
from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger

from cryovit.config import BaseExperimentConfig
from cryovit.models import create_sam_model_from_weights
from cryovit.types import ModelType
from cryovit.utils import save_model

# Process-wide PyTorch setting applied once at import time: use the "high"
# float32 matmul precision mode (PyTorch's speed/precision trade-off knob).
torch.set_float32_matmul_precision(precision="high")

## For Scripts


def run_training(
    train_data: list[Path],
    train_labels: list[Path],
    labels: list[str],
    model_type: ModelType,
    model_name: str,
    label_key: str,
    result_dir: Path,
    val_data: list[Path] | None = None,
    val_labels: list[Path] | None = None,
    num_epochs: int = 50,
    log_training: bool = False,
) -> Path:
    """Run training on the specified data and labels.

    Composes the ``train_model`` Hydra config with overrides for the model
    name, label key, model type, the ``file`` datamodule and the epoch count.
    It then builds the datamodule, model and trainer, fits the model, and
    saves it with ``save_model``.

    Args:
        train_data (list[Path]): List of paths to the training tomograms.
        train_labels (list[Path]): List of paths to the training labels.
        labels (list[str]): List of label names to train on.
        model_type (ModelType): Type of the model to train.
        model_name (str): Name of the model.
        label_key (str): Key for the label in the dataset.
        result_dir (Path): Directory where the training results will be saved.
            A plain ``str`` is accepted. The directory (and its parents) is
            created if missing.
        val_data (Optional[list[Path]], optional): List of paths to the
            validation tomograms. Defaults to None.
        val_labels (Optional[list[Path]], optional): List of paths to the
            validation labels. Defaults to None.
        num_epochs (int, optional): Number of training epochs. Defaults to 50.
        log_training (bool, optional): Whether to log training metrics to
            Tensorboard. Defaults to False.

    Returns:
        Path: Path to the saved model file
        (``result_dir / f"{model_name}.model"``).
    """
    # Normalise and create the output directory before training, so that
    # saving the model at the very end cannot fail on a missing directory.
    result_dir = Path(result_dir)
    result_dir.mkdir(parents=True, exist_ok=True)
    save_model_path = result_dir / f"{model_name}.model"

    ## Setup hydra config
    with initialize(
        version_base="1.2",
        config_path="../configs",
        job_name="cryovit_train",
    ):
        cfg = compose(
            config_name="train_model",
            overrides=[
                f"name={model_name}",
                f"label_key={label_key}",
                f"model={model_type.value}",
                "datamodule=file",
                f"trainer.max_epochs={num_epochs}",
            ],
        )
    # Foundation-model weights live next to this package, not in the config.
    cfg.paths.model_dir = Path(__file__).parent.parent / "foundation_models"

    # Check input key: anything other than "dino_features" is cleared so the
    # dataset finds the available data instead.
    if cfg.model.input_key != "dino_features":
        cfg.model.input_key = None

    ## Setup dataset
    dataset_fn = instantiate(cfg.datamodule.dataset)
    dataloader_fn = instantiate(cfg.datamodule.dataloader)
    datamodule = instantiate(cfg.datamodule, _convert_="all")(
        data_paths=train_data,
        data_labels=train_labels,
        labels=labels,
        val_paths=val_data,
        val_labels=val_labels,
        dataloader_fn=dataloader_fn,
        dataset_fn=dataset_fn,
    )
    logging.info("Setup dataset.")

    ## Setup training
    callbacks = [instantiate(cb_cfg) for cb_cfg in cfg.callbacks.values()]
    loggers = []
    if log_training:
        # tensorboard logger to avoid wandb account issues
        loggers.append(TensorBoardLogger(save_dir=result_dir, name=model_name))
        logging.info(
            "Setup TensorBoard logger. View logs with `tensorboard --logdir %s`",
            result_dir / model_name,
        )
    trainer = instantiate(cfg.trainer, callbacks=callbacks, logger=loggers)

    is_sam2 = cfg.model._target_ == "cryovit.models.sam2.SAM2"
    if is_sam2:
        # Load SAM2 pre-trained models
        model = create_sam_model_from_weights(
            cfg.model, cfg.paths.model_dir / cfg.paths.sam_name
        )
    else:
        model = instantiate(cfg.model)
    logging.info("Loaded model.")

    # Compilation is an optional speed-up. On failure we warn and keep
    # training the eager model, so this is not an error condition.
    # NOTE(review): torch.compile compiles lazily on first call, so these
    # try/except blocks likely only catch setup-time failures — confirm.
    if is_sam2:
        # Base SAM2 only supports image encoder compilation
        logging.info("Compiling image encoder for SAM2 model.")
        try:
            model.compile()
        except Exception as e:  # noqa: BLE001
            logging.warning("Unable to compile image encoder for SAM2: %s", e)
    else:
        logging.info("Compiling model forward pass.")
        try:
            model.forward = torch.compile(model.forward)
        except Exception as e:  # noqa: BLE001
            logging.warning("Unable to compile forward pass: %s", e)

    logging.info("Starting training.")
    trainer.fit(model, datamodule=datamodule)

    # Save model
    logging.info("Saving model.")
    save_model(model_name, label_key, model, cfg.model, save_model_path)
    return save_model_path
## For Experiments


def _join_samples(sample: str | Iterable[str] | None) -> str | None:
    """Collapse a collection of sample names into one sorted, "_"-joined string.

    A plain string (a single sample name) and non-iterables such as ``None``
    are returned unchanged. This ensures ``"Q18"`` is not split into its
    characters.
    """
    if not isinstance(sample, str) and isinstance(sample, Iterable):
        return "_".join(sorted(sample))
    return sample


def setup_exp_dir(cfg: BaseExperimentConfig) -> BaseExperimentConfig:
    """Setup the experiment directory structure and optionally, the W&B logger.

    Converts the ``cfg.paths`` entries to ``Path`` objects. It then creates
    ``exp_dir / name / <sample>[/split_<id>][/<test_sample>]`` and stores it
    back in ``cfg.paths.exp_dir``. The ``<test_sample>`` level is only added
    for ``FractionalSampleDataModule``. If a ``wandb`` logger is configured,
    its run name is set from the (test) sample and split id.

    Args:
        cfg (BaseExperimentConfig): Experiment configuration. It is modified
            in place.

    Returns:
        BaseExperimentConfig: The same ``cfg`` object.
    """
    # Convert paths to Paths
    cfg.paths.model_dir = Path(cfg.paths.model_dir)
    cfg.paths.data_dir = Path(cfg.paths.data_dir)
    cfg.paths.exp_dir = Path(cfg.paths.exp_dir)
    cfg.paths.results_dir = Path(cfg.paths.results_dir)

    sample = _join_samples(cfg.datamodule.sample)
    test_sample = _join_samples(cfg.datamodule.test_sample)

    new_exp_dir = cfg.paths.exp_dir / cfg.name / sample
    if cfg.datamodule.split_id is not None:
        new_exp_dir = new_exp_dir / f"split_{cfg.datamodule.split_id}"
    if (
        cfg.datamodule._target_ == "cryovit.datamodules.FractionalSampleDataModule"
        and test_sample is not None
    ):
        new_exp_dir = new_exp_dir / f"{test_sample}"
    # A single mkdir with parents=True also creates every intermediate level.
    new_exp_dir.mkdir(parents=True, exist_ok=True)
    cfg.paths.exp_dir = new_exp_dir

    # Setup WandB Logger
    run_name = test_sample or sample
    if cfg.datamodule.split_id is not None:
        run_name = f"{run_name}_{cfg.datamodule.split_id}"
    for name, lg in cfg.logger.items():
        if name == "wandb":
            lg.name = run_name
    return cfg


def run_trainer(cfg: BaseExperimentConfig) -> None:
    """Sets up and runs the training process using the specified configuration.

    Seeds all RNGs and prepares the experiment directory. It then builds the
    datamodule, trainer and model, and logs hyperparameters to any configured
    loggers. The model is trained, resuming from a checkpoint when
    ``cfg.resume_ckpt`` is set and the checkpoint exists. Finally, the model
    ``state_dict`` is saved to ``<exp_dir>/weights.pt``.

    Args:
        cfg (BaseExperimentConfig): Configuration object containing all
            settings for the training process.
    """
    seed_everything(cfg.random_seed, workers=True)

    # Setup experiment directories
    cfg = setup_exp_dir(cfg)
    # cfg.ckpt_path may be a plain string from the config; normalise it so the
    # `.exists()` check below works for user-supplied checkpoints too.
    if cfg.ckpt_path is None:
        ckpt_path = cfg.paths.exp_dir / "last.ckpt"
    else:
        ckpt_path = Path(cfg.ckpt_path)
    weights_path = cfg.paths.exp_dir / "weights.pt"

    # Setup dataset
    dataset_fn = instantiate(cfg.datamodule.dataset)
    dataloader_fn = instantiate(cfg.datamodule.dataloader)
    split_file = cfg.paths.data_dir / cfg.paths.csv_name / cfg.paths.split_name
    datamodule = instantiate(cfg.datamodule, _convert_="all")(
        split_file=split_file,
        dataloader_fn=dataloader_fn,
        dataset_fn=dataset_fn,
    )
    logging.info("Setup dataset.")

    # Setup training
    callbacks = [instantiate(cb_cfg) for cb_cfg in cfg.callbacks.values()]
    loggers = [instantiate(lg_cfg) for lg_cfg in cfg.logger.values()]
    trainer: Trainer = instantiate(
        cfg.trainer, callbacks=callbacks, logger=loggers
    )
    logging.info("Setup trainer.")

    is_sam2 = cfg.model._target_ == "cryovit.models.sam2.SAM2"
    if is_sam2:
        # Load SAM2 pre-trained models
        model = create_sam_model_from_weights(
            cfg.model, cfg.paths.model_dir / cfg.paths.sam_name
        )
    else:
        model = instantiate(cfg.model)
    logging.info("Setup model.")

    # Log hyperparameters
    if trainer.loggers:
        # HydraConfig is only populated under a Hydra-managed run (not e.g.
        # when cfg was built with the compose API), and get() raises otherwise.
        datamodule_type = (
            HydraConfig.get().runtime.choices["datamodule"]
            if HydraConfig.initialized()
            else None
        )
        n_total = sum(p.numel() for p in model.parameters())
        n_trainable = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
        hparams = {
            "datamodule_type": datamodule_type,
            "model_name": cfg.model.name,
            "label_key": cfg.label_key,
            "experiment": cfg.name,
            "split_id": cfg.datamodule.split_id,
            # Same joining rule as the experiment directory names.
            "sample": _join_samples(cfg.datamodule.sample),
            "test_sample": _join_samples(cfg.datamodule.test_sample),
            "cfg": cfg,
            "model": model,
            "model/params/total": n_total,
            "model/params/trainable": n_trainable,
            "model/params/non_trainable": n_total - n_trainable,
            "datamodule": datamodule,
            "trainer": trainer,
            "resume_ckpt": cfg.resume_ckpt,
            "ckpt_path": cfg.ckpt_path,
            "seed": cfg.random_seed,
        }
        if is_sam2:
            custom_kwargs = cfg.model.custom_kwargs
            hparams["prompt_lr"] = (
                custom_kwargs.get("prompt_lr", None) if custom_kwargs else None
            )
        for lg in trainer.loggers:
            lg.log_hyperparams(hparams)

    # Compilation is an optional speed-up: on failure, warn and train eagerly.
    if is_sam2:
        # Base SAM2 only supports image encoder compilation
        logging.info("Compiling image encoder for SAM2 model.")
        try:
            model.compile()
        except Exception as e:  # noqa: BLE001
            logging.warning("Unable to compile image encoder for SAM2: %s", e)
    else:
        logging.info("Compiling model forward pass.")
        try:
            model.forward = torch.compile(model.forward)
        except Exception as e:  # noqa: BLE001
            logging.warning("Unable to compile forward pass: %s", e)

    logging.info("Starting training.")
    resume_from = None
    if cfg.resume_ckpt:
        if ckpt_path.exists():
            logging.info("Resuming training from checkpoint: %s", ckpt_path)
            resume_from = str(ckpt_path)
        else:
            logging.warning(
                "Checkpoint %s not found; training from scratch.", ckpt_path
            )
    # ckpt_path=None is Lightning's default, i.e. start a fresh fit.
    trainer.fit(model, datamodule=datamodule, ckpt_path=resume_from)

    # Save model
    logging.info("Saving model.")
    torch.save(model.state_dict(), weights_path)