cryovit.datamodules
Implementations of PyTorch Lightning DataModules for loading Cryo-EM tomograms.
Classes
BaseDataModule | Base module defining common functions for creating data loaders for experiments.
FractionalSampleDataModule | Data module for fractional leave-one-out CryoVIT experiments.
SingleSampleDataModule | Data module for CryoVIT experiments involving a single sample.
MultiSampleDataModule | Data module for CryoVIT experiments involving multiple samples.
FileDataModule | Module defining common functions for creating data loaders for file-based datasets.
- class BaseDataModule(split_file: Path, dataset_fn: Callable, dataloader_fn: Callable, **kwargs)
Bases: LightningDataModule, ABC
Base module defining common functions for creating data loaders for experiments.
- __init__(split_file: Path, dataset_fn: Callable, dataloader_fn: Callable, **kwargs) -> None
Initializes the BaseDataModule with a dataset function, a dataloader function, and a path to the split file.
- Parameters:
split_file (Path) – The path to the .csv file containing data splits.
dataset_fn (Callable) – Function to create a Dataset from a dataframe of records.
dataloader_fn (Callable) – Function to create a DataLoader from a dataset.
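The two callables are described only by their roles, so the following is a minimal sketch of how such factories could be prepared with functools.partial. The names make_dataset, make_loader, input_key, batch_size, and the record fields are hypothetical stand-ins, not part of CryoVIT; a real setup would presumably return a PyTorch Dataset and DataLoader rather than plain lists.

```python
from functools import partial


def make_dataset(records, input_key):
    """Hypothetical dataset factory: keep one field from each record."""
    return [record[input_key] for record in records]


def make_loader(dataset, batch_size):
    """Hypothetical loader factory: group the dataset into fixed-size batches."""
    return [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]


# Bind the configuration up front so each factory takes a single argument,
# matching the roles described for dataset_fn and dataloader_fn.
dataset_fn = partial(make_dataset, input_key="tomo_name")
dataloader_fn = partial(make_loader, batch_size=2)

# Illustrative records; the field names are assumptions.
records = [
    {"tomo_name": "a.hdf", "sample": "s1"},
    {"tomo_name": "b.hdf", "sample": "s1"},
    {"tomo_name": "c.hdf", "sample": "s2"},
]
batches = dataloader_fn(dataset_fn(records))
```

Binding options with partial keeps the data module unaware of dataset and loader settings, which is one plausible reason for passing factories instead of instances.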
- train_dataloader() -> DataLoader
Creates DataLoader for training data.
- Returns:
A DataLoader instance for training data.
- Return type:
DataLoader
- val_dataloader() -> DataLoader
Creates DataLoader for validation data.
- Returns:
A DataLoader instance for validation data.
- Return type:
DataLoader
- class FractionalSampleDataModule(sample: list[str], split_id: int | None, split_key: str | None, test_sample: list[str] | None = None, **kwargs)
Bases: BaseDataModule
Data module for fractional leave-one-out CryoVIT experiments.
- __init__(sample: list[str], split_id: int | None, split_key: str | None, test_sample: list[str] | None = None, **kwargs) -> None
Train on a fraction of tomograms and leave out one sample for evaluation.
- Parameters:
sample (list[str]) – The samples to train and test on.
split_id (Optional[int]) – The number of splits used for training. If None, defaults to all splits.
split_key (Optional[str]) – The key used to select splits using split_id.
test_sample (Optional[list[str]]) – The sample to exclude from training and use for testing.
- train_df() -> DataFrame
Train tomograms: include a subset of all splits, leaving out one sample.
- Returns:
A dataframe specifying the train tomograms.
- Return type:
pd.DataFrame
- val_df() -> DataFrame
Validation tomograms: validate on tomograms from the held out sample.
- Returns:
A dataframe specifying the validation tomograms.
- Return type:
pd.DataFrame
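The fractional leave-one-out selection described above can be sketched with plain records in place of a dataframe. This is an illustration under stated assumptions, not the library's implementation: it assumes a record belongs to the training fraction when its value under split_key is below split_id, and the helper fractional_split, the "sample" field, and the "split_10" column are hypothetical.

```python
def fractional_split(records, sample, test_sample, split_id, split_key):
    """Return (train, val) record lists for a leave-one-out run.

    Assumption: a record joins the training fraction when its split value
    is below split_id; split_id=None keeps every split.
    """
    held_out = set(test_sample or [])
    train_samples = [s for s in sample if s not in held_out]
    train = [
        r for r in records
        if r["sample"] in train_samples
        and (split_id is None or r[split_key] < split_id)
    ]
    # Validation uses every tomogram from the held-out sample.
    val = [r for r in records if r["sample"] in held_out]
    return train, val


# Illustrative records; the column names are assumptions.
records = [
    {"tomo_name": "t0", "sample": "A", "split_10": 0},
    {"tomo_name": "t1", "sample": "A", "split_10": 5},
    {"tomo_name": "t2", "sample": "B", "split_10": 1},
    {"tomo_name": "t3", "sample": "C", "split_10": 3},
]
train, val = fractional_split(
    records, ["A", "B", "C"], ["C"], split_id=2, split_key="split_10"
)
train_all, _ = fractional_split(
    records, ["A", "B", "C"], ["C"], split_id=None, split_key="split_10"
)
```

Raising split_id enlarges the training fraction while the validation set stays fixed, which is what makes results at different fractions comparable.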
- class SingleSampleDataModule(sample: list[str], split_id: int | None, split_key: str, test_sample: list[str] | None = None, **kwargs)
Bases: BaseDataModule
Data module for CryoVIT experiments involving a single sample.
- __init__(sample: list[str], split_id: int | None, split_key: str, test_sample: list[str] | None = None, **kwargs) -> None
Create a datamodule for training and testing on a single sample.
- Parameters:
sample (list[str]) – The sample to train on.
split_id (Optional[int]) – An optional split_id to validate with.
split_key (str) – The key used to select splits using split_id.
test_sample (Optional[list[str]]) – The sample to test on. If None, test on the validation set.
- train_df() -> DataFrame
Train tomograms: exclude those from the sample with the specified split_id.
- Returns:
A dataframe specifying the train tomograms.
- Return type:
pd.DataFrame
- val_df() -> DataFrame
Validation tomograms: optionally validate on tomograms with the specified split_id.
- Returns:
A dataframe specifying the validation tomograms.
- Return type:
pd.DataFrame
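The train/validation behaviour described above amounts to holding out one split as a validation fold. A minimal sketch, using plain records instead of a dataframe: the helper holdout_split and the field names are hypothetical, and returning an empty validation set when split_id is None is an assumption about what "optionally" means here.

```python
def holdout_split(records, sample, split_id, split_key):
    """Return (train, val) record lists for the given samples.

    Assumptions: records whose split value equals split_id are held out for
    validation, and split_id=None trains on everything with no validation set.
    """
    in_sample = [r for r in records if r["sample"] in sample]
    if split_id is None:
        return in_sample, []
    train = [r for r in in_sample if r[split_key] != split_id]
    val = [r for r in in_sample if r[split_key] == split_id]
    return train, val


# Illustrative records; the column names are assumptions.
records = [
    {"tomo_name": "t0", "sample": "A", "split_id": 0},
    {"tomo_name": "t1", "sample": "A", "split_id": 1},
    {"tomo_name": "t2", "sample": "A", "split_id": 0},
    {"tomo_name": "t3", "sample": "B", "split_id": 1},
]

# Cross-validation: each split value takes one turn as the validation fold.
folds = {
    split_id: holdout_split(records, ["A"], split_id, "split_id")
    for split_id in (0, 1)
}
```

Passing several names in sample would give the multi-sample case under the same assumptions.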
- class MultiSampleDataModule(sample: list[str], split_id: int | None, split_key: str | None, test_sample: list[str] | None = None, **kwargs)
Bases: BaseDataModule
Data module for CryoVIT experiments involving multiple samples.
- __init__(sample: list[str], split_id: int | None, split_key: str | None, test_sample: list[str] | None = None, **kwargs) -> None
Create a datamodule for training and testing on multiple samples.
- Parameters:
sample (list[str]) – list of samples used for training.
split_id (Optional[int]) – An optional split ID for validation.
split_key (Optional[str]) – The key used to select splits using split_id.
test_sample (Optional[list[str]]) – list of samples used for testing. If None, test on the validation set.
- train_df() -> DataFrame
Train tomograms: exclude those with the specified split_id.
- Returns:
A dataframe specifying the train tomograms.
- Return type:
pd.DataFrame
- val_df() -> DataFrame
Validation tomograms: optionally validate on tomograms with the specified split_id.
- Returns:
A dataframe specifying the validation tomograms.
- Return type:
pd.DataFrame
- class FileDataModule(data_paths: list[Path], dataset_fn: Callable, dataloader_fn: Callable, val_paths: list[Path] | None = None, data_labels: list[Path] | None = None, val_labels: list[Path] | None = None, labels: list[str] | None = None, **kwargs)
Bases: LightningDataModule
Module defining common functions for creating data loaders for file-based datasets.
- __init__(data_paths: list[Path], dataset_fn: Callable, dataloader_fn: Callable, val_paths: list[Path] | None = None, data_labels: list[Path] | None = None, val_labels: list[Path] | None = None, labels: list[str] | None = None, **kwargs) -> None
Initializes the FileDataModule with paths to the data and label files, a dataset function, and a dataloader function.
- Parameters:
data_paths (list[Path]) – A list of paths to the data files for training/testing/prediction.
dataset_fn (Callable) – Function to create a Dataset from a list of FileData objects.
dataloader_fn (Callable) – Function to create a DataLoader from a dataset.
val_paths (Optional[list[Path]]) – A list of paths to the data files for validation.
data_labels (Optional[list[Path]]) – A list of paths to the label files for training/testing/prediction. Should only be missing for inference.
val_labels (Optional[list[Path]]) – A list of paths to the label files for validation.
labels (Optional[list[str]]) – A list of label keys to load from the label files. Should only be missing if no labels are provided.
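The relationship between the data paths, label paths, and label keys can be illustrated with a small pairing helper. This is a sketch under assumptions: FileRecord is a hypothetical stand-in for the FileData objects mentioned above (whose actual fields are not documented here), and pair_files, the file names, and the "mito" label key are invented for illustration.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class FileRecord:
    """Hypothetical stand-in for a FileData object."""
    tomo_path: Path
    label_path: Optional[Path]
    labels: Optional[list]


def pair_files(data_paths, data_labels=None, labels=None):
    """Pair each data file with its label file, if label files are given."""
    if data_labels is None:
        # Inference: no label files, so no label keys are needed either.
        return [FileRecord(p, None, None) for p in data_paths]
    if len(data_paths) != len(data_labels):
        raise ValueError("data_paths and data_labels must have the same length")
    if not labels:
        raise ValueError("labels is required when label files are provided")
    return [
        FileRecord(p, lab, list(labels))
        for p, lab in zip(data_paths, data_labels)
    ]


# Training-style call: one label file per data file, plus the keys to load.
train_records = pair_files(
    [Path("a.hdf")], [Path("a_labels.hdf")], labels=["mito"]
)
# Inference-style call: data files only.
infer_records = pair_files([Path("b.hdf")])
```

The same pairing would apply to val_paths and val_labels; validating lengths early turns a mismatched file list into an immediate error rather than a failure during training.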
- train_dataloader() -> DataLoader
Creates DataLoader for training data.
- Returns:
A DataLoader instance for training data.
- Return type:
DataLoader
- val_dataloader() -> DataLoader
Creates DataLoader for validation data.
- Returns:
A DataLoader instance for validation data.
- Return type:
DataLoader