cryovit.datamodules

Implementations of PyTorch Lightning DataModules for loading Cryo-EM tomograms.

Classes

BaseDataModule(split_file, dataset_fn, ...)

Base module defining common functions for creating data loaders for experiments.

FractionalSampleDataModule(sample, split_id, ...)

Data module for fractional leave-one-out CryoVIT experiments.

SingleSampleDataModule(sample, split_id, ...)

Data module for CryoVIT experiments involving a single sample.

MultiSampleDataModule(sample, split_id, ...)

Data module for CryoVIT experiments involving multiple samples.

FileDataModule(data_paths, dataset_fn, ...)

Module defining common functions for creating data loaders for file-based datasets.

class BaseDataModule(split_file: Path, dataset_fn: Callable, dataloader_fn: Callable, **kwargs)

Bases: LightningDataModule, ABC

Base module defining common functions for creating data loaders for experiments.

__init__(split_file: Path, dataset_fn: Callable, dataloader_fn: Callable, **kwargs) → None

Initializes the BaseDataModule with a dataset function, a dataloader function, and the path to the split file.

Parameters:
  • split_file (Path) – The path to the .csv file containing data splits.

  • dataset_fn (Callable) – Function to create a Dataset from a dataframe of records.

  • dataloader_fn (Callable) – Function to create a DataLoader from a dataset.
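The parameters above suggest that dataset_fn and dataloader_fn are single-argument factories. A minimal sketch of that pattern follows, assuming options are bound ahead of time with functools.partial. To keep it runnable with the standard library alone, a list of record dicts stands in for the split DataFrame and a list of batches stands in for a PyTorch DataLoader; make_dataset, make_loader, and the tomo_name field are hypothetical, not part of cryovit.

```python
from functools import partial
from typing import Any

Record = dict[str, Any]  # stand-in for one row of the split dataframe


def make_dataset(records: list[Record], normalize: bool = True) -> list[Record]:
    """Hypothetical dataset_fn: turn split records into dataset items."""
    return [{**record, "normalize": normalize} for record in records]


def make_loader(dataset: list[Record], batch_size: int = 1) -> list[list[Record]]:
    """Hypothetical dataloader_fn: group dataset items into fixed-size batches."""
    return [dataset[i : i + batch_size] for i in range(0, len(dataset), batch_size)]


# Bind options up front so each factory takes exactly one positional argument.
dataset_fn = partial(make_dataset, normalize=False)
dataloader_fn = partial(make_loader, batch_size=2)

records = [{"tomo_name": f"tomo_{i}"} for i in range(5)]
batches = dataloader_fn(dataset_fn(records))
print([len(batch) for batch in batches])  # [2, 2, 1]
```

Binding options this way keeps the data module itself unaware of batch size or preprocessing settings.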

abstract train_df() → DataFrame

Abstract method that returns the training split.

abstract val_df() → DataFrame

Abstract method that returns the validation split.

abstract test_df() → DataFrame

Abstract method that returns the test split.

abstract predict_df() → DataFrame

Abstract method that returns the prediction split.

train_dataloader() → DataLoader

Creates DataLoader for training data.

Returns:

A DataLoader instance for training data.

Return type:

DataLoader

val_dataloader() → DataLoader

Creates DataLoader for validation data.

Returns:

A DataLoader instance for validation data.

Return type:

DataLoader

test_dataloader() → DataLoader

Creates DataLoader for testing data.

Returns:

A DataLoader instance for testing data.

Return type:

DataLoader

predict_dataloader() → DataLoader

Creates DataLoader for prediction data.

Returns:

A DataLoader instance for prediction data.

Return type:

DataLoader
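The contract above can be sketched without Lightning, assuming each *_dataloader method simply feeds the matching *_df split through dataset_fn and then dataloader_fn. SketchDataModule, HoldOutModule, and the split_id record field are hypothetical stand-ins, and lists of dicts replace DataFrames.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable

Record = dict[str, Any]


class SketchDataModule(ABC):
    """Hypothetical, Lightning-free stand-in for BaseDataModule."""

    def __init__(self, records: list[Record], dataset_fn: Callable, dataloader_fn: Callable):
        self.records = records  # stand-in for the rows loaded from split_file
        self.dataset_fn = dataset_fn
        self.dataloader_fn = dataloader_fn

    @abstractmethod
    def train_df(self) -> list[Record]: ...

    @abstractmethod
    def val_df(self) -> list[Record]: ...

    @abstractmethod
    def test_df(self) -> list[Record]: ...

    @abstractmethod
    def predict_df(self) -> list[Record]: ...

    # Assumed composition: split records -> dataset -> loader.
    def train_dataloader(self):
        return self.dataloader_fn(self.dataset_fn(self.train_df()))

    def val_dataloader(self):
        return self.dataloader_fn(self.dataset_fn(self.val_df()))

    def test_dataloader(self):
        return self.dataloader_fn(self.dataset_fn(self.test_df()))

    def predict_dataloader(self):
        return self.dataloader_fn(self.dataset_fn(self.predict_df()))


class HoldOutModule(SketchDataModule):
    """Hypothetical subclass: hold out one split for validation and testing."""

    def __init__(self, records, dataset_fn, dataloader_fn, split_id: int):
        super().__init__(records, dataset_fn, dataloader_fn)
        self.split_id = split_id

    def train_df(self):
        return [r for r in self.records if r["split_id"] != self.split_id]

    def val_df(self):
        return [r for r in self.records if r["split_id"] == self.split_id]

    def test_df(self):
        return self.val_df()

    def predict_df(self):
        return list(self.records)


records = [{"tomo_name": f"tomo_{i}", "split_id": i % 3} for i in range(6)]
module = HoldOutModule(records, dataset_fn=list, dataloader_fn=list, split_id=0)
print(len(module.train_dataloader()), len(module.val_dataloader()))  # 4 2
```

Because the four *_df methods are abstract, the base class cannot be instantiated directly; a subclass only decides which records land in each split.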

class FractionalSampleDataModule(sample: list[str], split_id: int | None, split_key: str | None, test_sample: list[str] | None = None, **kwargs)

Bases: BaseDataModule

Data module for fractional leave-one-out CryoVIT experiments.

__init__(sample: list[str], split_id: int | None, split_key: str | None, test_sample: list[str] | None = None, **kwargs) → None

Train on a fraction of tomograms and leave out one sample for evaluation.

Parameters:
  • sample (list[str]) – The samples to train and test on.

  • split_id (Optional[int]) – The number of splits used for training. If None, defaults to all splits.

  • split_key (Optional[str]) – The key used to select splits using split_id.

  • test_sample (Optional[list[str]]) – The sample(s) to hold out from training and use for validation and testing.

train_df() → DataFrame

Train tomograms: include a subset of all splits, leaving out one sample.

Returns:

A dataframe specifying the train tomograms.

Return type:

pd.DataFrame

val_df() → DataFrame

Validation tomograms: validate on tomograms from the held out sample.

Returns:

A dataframe specifying the validation tomograms.

Return type:

pd.DataFrame

test_df() → DataFrame

Test tomograms: test on tomograms from the held out sample.

Returns:

A dataframe specifying the test tomograms.

Return type:

pd.DataFrame

predict_df() → DataFrame

Predict tomograms: predict on the specified samples.

Returns:

A dataframe specifying the predict tomograms.

Return type:

pd.DataFrame
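A minimal sketch of the fractional leave-one-out selection, under stated assumptions: each record carries a sample field, the split column is named by split_key, splits are numbered from 0, and split_id keeps the first split_id splits for training. fractional_splits is a hypothetical helper that returns plain lists instead of DataFrames; it is not the cryovit implementation.

```python
from typing import Any, Optional

Record = dict[str, Any]


def fractional_splits(
    records: list[Record],
    sample: list[str],
    split_id: Optional[int],
    split_key: str,
    test_sample: list[str],
) -> dict[str, list[Record]]:
    """Hypothetical fractional leave-one-out selection over list-of-dict records."""
    held_out = set(test_sample)
    # Training pool: requested samples minus the held-out sample(s).
    pool = [r for r in records if r["sample"] in sample and r["sample"] not in held_out]
    # Assumption: splits are numbered 0..k-1 and split_id counts how many to keep.
    train = pool if split_id is None else [r for r in pool if r[split_key] < split_id]
    evaluation = [r for r in records if r["sample"] in held_out]
    return {
        "train": train,
        "val": evaluation,
        "test": evaluation,
        "predict": [r for r in records if r["sample"] in sample],
    }


records = [
    {"tomo_name": f"{s}_{k}", "sample": s, "split_id": k}
    for s in ("A", "B", "C")
    for k in range(4)
]
splits = fractional_splits(records, ["A", "B", "C"], 2, "split_id", test_sample=["C"])
print(len(splits["train"]), len(splits["test"]))  # 4 4
```

Sweeping split_id upward enlarges the training fraction while the held-out sample, and therefore the validation and test sets, stays fixed.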

class SingleSampleDataModule(sample: list[str], split_id: int | None, split_key: str, test_sample: list[str] | None = None, **kwargs)

Bases: BaseDataModule

Data module for CryoVIT experiments involving a single sample.

__init__(sample: list[str], split_id: int | None, split_key: str, test_sample: list[str] | None = None, **kwargs) → None

Create a datamodule for training and testing on a single sample.

Parameters:
  • sample (list[str]) – The sample to train on.

  • split_id (Optional[int]) – An optional split_id to validate with.

  • split_key (str) – The key used to select splits using split_id.

  • test_sample (Optional[list[str]]) – The sample to test on. If None, test on the validation set.

train_df() → DataFrame

Train tomograms: tomograms from the sample, excluding those with the specified split_id.

Returns:

A dataframe specifying the train tomograms.

Return type:

pd.DataFrame

val_df() → DataFrame

Validation tomograms: optionally validate on tomograms with the specified split_id.

Returns:

A dataframe specifying the validation tomograms.

Return type:

pd.DataFrame

test_df() → DataFrame

Test tomograms: test on tomograms from test_sample if it is given; otherwise, test on the validation tomograms.

Returns:

A dataframe specifying the test tomograms.

Return type:

pd.DataFrame

predict_df() → DataFrame

Predict tomograms: predict on the whole sample.

Returns:

A dataframe specifying the predict tomograms.

Return type:

pd.DataFrame
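The single-sample rules can be sketched as follows, assuming each record has a sample field and a split column named by split_key, and that no validation set is formed when split_id is None. single_sample_splits is a hypothetical helper returning plain lists, not the cryovit implementation.

```python
from typing import Any, Optional

Record = dict[str, Any]


def single_sample_splits(
    records: list[Record],
    sample: list[str],
    split_id: Optional[int],
    split_key: str,
    test_sample: Optional[list[str]] = None,
) -> dict[str, list[Record]]:
    """Hypothetical single-sample selection over list-of-dict records."""
    own = [r for r in records if r["sample"] in sample]
    if split_id is None:
        train, val = own, []  # assumption: no validation split requested
    else:
        train = [r for r in own if r[split_key] != split_id]
        val = [r for r in own if r[split_key] == split_id]
    # test_sample takes precedence; otherwise reuse the validation tomograms.
    test = [r for r in records if r["sample"] in test_sample] if test_sample else val
    return {"train": train, "val": val, "test": test, "predict": own}


records = [
    {"tomo_name": f"{s}_{k}", "sample": s, "split_id": k}
    for s in ("A", "B")
    for k in range(5)
]
s = single_sample_splits(records, sample=["A"], split_id=3, split_key="split_id")
print(len(s["train"]), len(s["val"]), len(s["test"]))  # 4 1 1
```

The membership test on sample works unchanged when the list names several samples.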

class MultiSampleDataModule(sample: list[str], split_id: int | None, split_key: str | None, test_sample: list[str] | None = None, **kwargs)

Bases: BaseDataModule

Data module for CryoVIT experiments involving multiple samples.

__init__(sample: list[str], split_id: int | None, split_key: str | None, test_sample: list[str] | None = None, **kwargs) → None

Create a datamodule for training and testing on multiple samples.

Parameters:
  • sample (list[str]) – List of samples used for training.

  • split_id (Optional[int]) – An optional split ID for validation.

  • split_key (Optional[str]) – The key used to select splits using split_id.

  • test_sample (Optional[list[str]]) – List of samples used for testing. If None, test on the validation set.

train_df() → DataFrame

Train tomograms: exclude those with the specified split_id.

Returns:

A dataframe specifying the train tomograms.

Return type:

pd.DataFrame

val_df() → DataFrame

Validation tomograms: optionally validate on tomograms with the specified split_id.

Returns:

A dataframe specifying the validation tomograms.

Return type:

pd.DataFrame

test_df() → DataFrame

Test tomograms: test on tomograms from the test samples.

Returns:

A dataframe specifying the test tomograms.

Return type:

pd.DataFrame

predict_df() → DataFrame

Predict tomograms: predict on the specified samples.

Returns:

A dataframe specifying the predict tomograms.

Return type:

pd.DataFrame
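One plausible use of split_id with several samples is k-fold style validation: sweep split_id over every split value and treat each one as the validation fold. The sketch below assumes a sample field and a split column named by split_key, and checks that every fold partitions the chosen tomograms; folds is a hypothetical helper, not part of cryovit.datamodules.

```python
from typing import Any, Iterator

Record = dict[str, Any]


def folds(
    records: list[Record], sample: list[str], split_key: str
) -> Iterator[tuple[int, list[Record], list[Record]]]:
    """Yield (split_id, train, val) for every split value found in the chosen samples."""
    chosen = [r for r in records if r["sample"] in sample]
    for split_id in sorted({r[split_key] for r in chosen}):
        train = [r for r in chosen if r[split_key] != split_id]
        val = [r for r in chosen if r[split_key] == split_id]
        yield split_id, train, val


records = [
    {"tomo_name": f"{s}_{k}", "sample": s, "split_id": k % 3}
    for s in ("A", "B")
    for k in range(6)
]
for split_id, train, val in folds(records, sample=["A", "B"], split_key="split_id"):
    # Each fold is a disjoint partition of the chosen tomograms.
    names_train = {r["tomo_name"] for r in train}
    names_val = {r["tomo_name"] for r in val}
    assert names_train.isdisjoint(names_val)
    assert len(names_train) + len(names_val) == 12
    print(split_id, len(train), len(val))  # 0 8 4, then 1 8 4, then 2 8 4
```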

class FileDataModule(data_paths: list[Path], dataset_fn: Callable, dataloader_fn: Callable, val_paths: list[Path] | None = None, data_labels: list[Path] | None = None, val_labels: list[Path] | None = None, labels: list[str] | None = None, **kwargs)

Bases: LightningDataModule

Module defining common functions for creating data loaders for file-based datasets.

__init__(data_paths: list[Path], dataset_fn: Callable, dataloader_fn: Callable, val_paths: list[Path] | None = None, data_labels: list[Path] | None = None, val_labels: list[Path] | None = None, labels: list[str] | None = None, **kwargs) → None

Initializes the FileDataModule with the data and label file paths, a dataset function, and a dataloader function.

Parameters:
  • data_paths (list[Path]) – A list of paths to the data files for training/testing/prediction.

  • dataset_fn (Callable) – Function to create a Dataset from a list of FileData objects.

  • dataloader_fn (Callable) – Function to create a DataLoader from a dataset.

  • val_paths (Optional[list[Path]]) – A list of paths to the data files for validation.

  • data_labels (Optional[list[Path]]) – A list of paths to the label files for training/testing/prediction. Should only be missing for inference.

  • val_labels (Optional[list[Path]]) – A list of paths to the label files for validation.

  • labels (Optional[list[str]]) – A list of label keys to load from the label files. Should only be missing if no labels are provided.
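A sketch of the constraints listed above, assuming data_paths and data_labels are matched by position. FilePair and pair_files are hypothetical (FilePair is not the FileData class mentioned above), and the file names and the mito label key are made up for illustration.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass(frozen=True)
class FilePair:
    """Hypothetical stand-in for one data file and its optional label file."""

    data: Path
    label: Optional[Path] = None


def pair_files(
    data_paths: list[Path],
    data_labels: Optional[list[Path]] = None,
    labels: Optional[list[str]] = None,
) -> list[FilePair]:
    """Pair data and label files by position, enforcing the documented constraints."""
    if data_labels is None:
        # Inference: no label files, so every pair carries label=None.
        return [FilePair(path) for path in data_paths]
    if len(data_labels) != len(data_paths):
        raise ValueError("data_labels must have one entry per data path")
    if not labels:
        raise ValueError("labels must list the keys to read when label files are given")
    return [FilePair(data, label_path) for data, label_path in zip(data_paths, data_labels)]


pairs = pair_files(
    [Path("tomo_a.hdf"), Path("tomo_b.hdf")],
    [Path("tomo_a_labels.hdf"), Path("tomo_b_labels.hdf")],
    labels=["mito"],
)
print([p.label.name for p in pairs])  # ['tomo_a_labels.hdf', 'tomo_b_labels.hdf']
```

The same check would apply to val_paths and val_labels if validation files are supplied.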

train_dataloader() → DataLoader

Creates DataLoader for training data.

Returns:

A DataLoader instance for training data.

Return type:

DataLoader

val_dataloader() → DataLoader

Creates DataLoader for validation data.

Returns:

A DataLoader instance for validation data.

Return type:

DataLoader

test_dataloader() → DataLoader

Creates DataLoader for testing data.

Returns:

A DataLoader instance for testing data.

Return type:

DataLoader

predict_dataloader() → DataLoader

Creates DataLoader for prediction data.

Returns:

A DataLoader instance for prediction data.

Return type:

DataLoader