Source code for cryovit.datamodules.fractional_sample_datamodule

"""Implementation of the fractional leave one out data module."""

import pandas as pd

from cryovit.datamodules.base_datamodule import BaseDataModule


class FractionalSampleDataModule(BaseDataModule):
    """Data module for fractional leave-one-out CryoVIT experiments.

    Training uses a subset ("fraction") of the splits from the configured
    samples, while a single held-out sample is used for validation and testing.
    """

    def __init__(
        self,
        sample: list[str],
        split_id: int | None,
        split_key: str | None,
        test_sample: list[str] | None = None,
        **kwargs,
    ) -> None:
        """Train on a fraction of tomograms and leave out one sample for evaluation.

        Args:
            sample (list[str]): The samples to train and test on.
            split_id (Optional[int]): The number of splits used for training.
                If None, defaults to all splits.
            split_key (str): The key (column of ``record_df``) used to select
                splits using split_id.
            test_sample (Optional[list[str]]): The sample to exclude from
                training and use for testing. Must be a one-element list.
            **kwargs: Forwarded unchanged to ``BaseDataModule.__init__``.

        Raises:
            AssertionError: If ``test_sample`` is None or does not contain
                exactly one entry.
        """
        super().__init__(**kwargs)

        # Validity checks (kept as assertions so callers see the same
        # exception type as before).
        assert (
            test_sample is not None
        ), "Fractional sample `test_sample` cannot be None."
        assert (
            len(test_sample) == 1
        ), f"Fractional sample 'test_sample' should be a single string list. Got {test_sample} instead."

        self.sample = sample
        self.split_id = split_id
        self.split_key = split_key
        self.test_sample = test_sample

    def _training_splits(self) -> list[int]:
        """Return the split ids that are selected by ``split_id``.

        Returns:
            list[int]: ``[0, ..., split_id - 1]`` when ``split_id`` is set,
            otherwise every split id from 0 up to and including the largest
            value found in ``record_df[split_key]``.
        """
        assert self.record_df is not None

        if self.split_id is not None:
            return list(range(self.split_id))

        # ``range(split_id)`` above starts at 0, i.e. split ids are treated as
        # 0-based. To really cover *all* splits the highest id itself has to be
        # included, hence the ``+ 1`` (the previous ``range(max)`` silently
        # dropped the last split).
        highest_split = int(self.record_df[self.split_key].max())
        return list(range(highest_split + 1))

    def train_df(self) -> pd.DataFrame:
        """Train tomograms: include a subset of all splits, leaving out one sample.

        Returns:
            pd.DataFrame: A dataframe specifying the train tomograms
            (columns ``sample`` and ``tomo_name``).
        """
        assert self.record_df is not None

        records = self.record_df
        in_selected_splits = records[self.split_key].isin(self._training_splits())
        in_samples = records["sample"].isin(self.sample)
        is_held_out = records["sample"].isin(self.test_sample)

        return records[in_selected_splits & in_samples & ~is_held_out][
            ["sample", "tomo_name"]
        ]

    def val_df(self) -> pd.DataFrame:
        """Validation tomograms: validate on tomograms from the held out sample.

        Returns:
            pd.DataFrame: A dataframe specifying the validation tomograms. The
            ``split_key`` column is included only when ``split_id`` is set.
        """
        assert self.record_df is not None

        keys = ["sample", "tomo_name"]
        if self.split_id is not None:
            # Keep the split column so ``test_df`` can detect it.
            keys.append(self.split_key)

        records = self.record_df
        in_selected_splits = records[self.split_key].isin(self._training_splits())
        is_held_out = records["sample"].isin(self.test_sample)

        return records[in_selected_splits & is_held_out][keys]

    def test_df(self) -> pd.DataFrame:
        """Test tomograms: test on tomograms from the held out sample.

        Returns:
            pd.DataFrame: A dataframe specifying the test tomograms.
        """
        # Work on an explicit copy: assigning a column to a selection of
        # ``record_df`` triggers pandas' SettingWithCopyWarning.
        df = self.val_df().copy()

        # replace split_id with split fraction if it exists
        # NOTE(review): the column written is literally "split_id"; if
        # ``split_key`` is a different name this adds a column instead of
        # replacing one -- confirm this is intended.
        if self.split_key in df.columns:
            df["split_id"] = self.split_id

        return df

    def predict_df(self) -> pd.DataFrame:
        """Predict tomograms: predict on the specified samples.

        Returns:
            pd.DataFrame: A dataframe specifying the predict tomograms
            (columns ``sample`` and ``tomo_name``).
        """
        assert self.record_df is not None

        in_samples = self.record_df["sample"].isin(self.sample)
        return self.record_df[in_samples][["sample", "tomo_name"]]