"""Module defining base data loading functionality for CryoVIT experiments."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from pathlib import Path
import pandas as pd
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from cryovit.datamodules.utils import collate_fn
[docs]
class BaseDataModule(LightningDataModule, ABC):
"""Base module defining common functions for creating data loaders for experiments."""
[docs]
def __init__(
self,
split_file: Path,
dataset_fn: Callable,
dataloader_fn: Callable,
**kwargs,
) -> None:
"""Initializes the BaseDataModule with dataset parameters, a dataloader function, and a path to the split file.
Args:
split_file (Path): The path to the .csv file containing data splits.
dataset_fn (Callable): Function to create a Dataset from a dataframe of records.
dataloader_fn (Callable): Function to create a DataLoader from a dataset.
"""
super().__init__()
self.dataset_fn = dataset_fn
self.dataloader_fn = dataloader_fn
self.split_file = (
split_file if isinstance(split_file, Path) else Path(split_file)
)
self.record_df = pd.read_csv(self.split_file)
[docs]
@abstractmethod
def train_df(self) -> pd.DataFrame:
"""Abstract method to generate train splits."""
raise NotImplementedError
[docs]
@abstractmethod
def val_df(self) -> pd.DataFrame:
"""Abstract method to generate validation splits."""
raise NotImplementedError
[docs]
@abstractmethod
def test_df(self) -> pd.DataFrame:
"""Abstract method to generate test splits."""
raise NotImplementedError
[docs]
@abstractmethod
def predict_df(self) -> pd.DataFrame:
"""Abstract method to generate predict splits."""
raise NotImplementedError
[docs]
def train_dataloader(self) -> DataLoader:
"""Creates DataLoader for training data.
Returns:
DataLoader: A DataLoader instance for training data.
"""
records = self.train_df()
if records.empty:
raise ValueError(
"No training data found in the provided split file."
)
dataset = self.dataset_fn(records, train=True)
return self.dataloader_fn(dataset, shuffle=True, collate_fn=collate_fn)
[docs]
def val_dataloader(self) -> DataLoader:
"""Creates DataLoader for validation data.
Returns:
DataLoader: A DataLoader instance for validation data.
"""
records = self.val_df()
if records.empty:
raise ValueError(
"No validation data found in the provided split file."
)
dataset = self.dataset_fn(records, train=False)
return self.dataloader_fn(
dataset, shuffle=False, collate_fn=collate_fn
)
[docs]
def test_dataloader(self) -> DataLoader:
"""Creates DataLoader for testing data.
Returns:
DataLoader: A DataLoader instance for testing data.
"""
records = self.test_df()
if records.empty:
raise ValueError(
"No testing data found in the provided split file."
)
dataset = self.dataset_fn(records, train=False)
return self.dataloader_fn(
dataset, shuffle=False, collate_fn=collate_fn
)
[docs]
def predict_dataloader(self) -> DataLoader:
"""Creates DataLoader for prediction data.
Returns:
DataLoader: A DataLoader instance for prediction data.
"""
records = self.predict_df()
if records.empty:
raise ValueError(
"No prediction data found in the provided split file."
)
dataset = self.dataset_fn(records, train=False)
return self.dataloader_fn(
dataset, shuffle=False, collate_fn=collate_fn
)