Source code for cryovit.datamodules.single_sample_datamodule

"""Implementation of the single sample data module."""

import pandas as pd

from cryovit.datamodules.base_datamodule import BaseDataModule


[docs] class SingleSampleDataModule(BaseDataModule): """Data module for CryoVIT experiments involving a single sample."""
[docs] def __init__( self, sample: list[str], split_id: int | None, split_key: str, test_sample: list[str] | None = None, **kwargs, ) -> None: """Create a datamodule for training and testing on a single sample. Args: sample (list[str]): The sample to train on. split_id (Optional[int]): An optional split_id to validate with. split_key (str): The key used to select splits using split_id. test_sample (Optional[list[str]]): The sample to test on. If None, test on the validation set. """ super().__init__(**kwargs) # Validity checks assert ( len(sample) == 1 ), f"Single sample 'sample' should be a single string list. Got {sample} instead." assert ( test_sample is None or len(test_sample) == 1 ), f"Single sample 'test_sample' should be a single string list or None. Got {test_sample} instead." self.sample = sample[0] self.split_id = split_id self.split_key = split_key self.test_sample = ( test_sample[0] if test_sample is not None else test_sample )
[docs] def train_df(self) -> pd.DataFrame: """Train tomograms: exclude those from the sample with the specified split_id. Returns: pd.DataFrame: A dataframe specifying the train tomograms. """ assert self.record_df is not None if self.split_id is not None: df = self.record_df[ (self.record_df[self.split_key] != self.split_id) & (self.record_df["sample"] == self.sample) ] else: df = self.record_df[self.record_df["sample"] == self.sample][ ["sample", "tomo_name"] ] return df
[docs] def val_df(self) -> pd.DataFrame: """Validation tomograms: optionally validate on tomograms with the specified split_id. Returns: pd.DataFrame: A dataframe specifying the validation tomograms. """ assert self.record_df is not None if self.split_id is None: # validate on train set return self.train_df() return self.record_df[ (self.record_df[self.split_key] == self.split_id) & (self.record_df["sample"] == self.sample) ]
[docs] def test_df(self) -> pd.DataFrame: """Test tomograms: test on tomograms from the specified test_sample or split_id. Returns: pd.DataFrame: A dataframe specifying the test tomograms. """ assert self.record_df is not None if self.test_sample is None: return self.val_df() # If testing on another sample, use the whole sample return self.record_df[self.record_df["sample"] == self.test_sample][ ["sample", "tomo_name"] ]
[docs] def predict_df(self) -> pd.DataFrame: """Predict tomograms: predict on the whole sample. Returns: pd.DataFrame: A dataframe specifying the predict tomograms. """ assert self.record_df is not None return self.record_df[self.record_df["sample"] == self.sample][ ["sample", "tomo_name"] ]