import lightning as L
from torch.utils.data import DataLoader

# import submodules
from .climate import ERA5Dataset, SSTDataset
from .speedy_weather import (
    SpeedyWeatherDataset,
    SpeedyWeatherDiscreteDataset,
    ShallowWaterDiscreteDataset,
)

# Public API of this package: the dataset classes re-exported above.
# `__all__` must list the *names* as strings; listing the class objects
# themselves makes `from <package> import *` raise a TypeError.
__all__ = [
    "SSTDataset",
    "ERA5Dataset",
    "SpeedyWeatherDataset",
    "SpeedyWeatherDiscreteDataset",
    "ShallowWaterDiscreteDataset",
]


# ---------------- pl DataModule ----------------
class LitDataModule(L.LightningDataModule):
    """Lightning DataModule building train/val/test splits from one dataset class.

    Each split is instantiated eagerly in ``__init__`` as
    ``dataset_class(mode=<split>, **kwargs)`` with ``<split>`` one of
    ``"train"``, ``"val"`` and ``"test"``.

    If a split's dataset exposes a ``collate_fn`` attribute, it is passed to
    that split's ``DataLoader``; otherwise PyTorch's default collation is used.

    Args:
        dataset_class: Callable (typically a ``Dataset`` subclass) accepting a
            ``mode`` keyword argument plus ``**kwargs``.
        batch_size: Number of samples per batch, shared by every loader.
        num_workers: Number of DataLoader worker processes, shared by every
            loader.
        **kwargs: Forwarded unchanged to every ``dataset_class`` call.
    """

    def __init__(self, dataset_class, batch_size: int = 1, num_workers: int = 8, **kwargs):
        super().__init__()
        self.train_set = dataset_class(mode="train", **kwargs)
        self.val_set = dataset_class(mode="val", **kwargs)
        self.test_set = dataset_class(mode="test", **kwargs)
        self.num_workers = num_workers
        self.batch_size = batch_size

    def _make_loader(self, dataset, *, shuffle: bool, collate_fn=None) -> DataLoader:
        """Build a DataLoader with the settings shared by every split."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # NOTE(review): drop_last=True also applies to val/test/predict, so
            # up to batch_size - 1 trailing samples are skipped there. Kept for
            # compatibility with models that may expect fixed-size batches.
            drop_last=True,
            # None falls back to PyTorch's default collation.
            collate_fn=collate_fn,
            pin_memory=True,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader."""
        return self._make_loader(
            self.train_set,
            shuffle=True,
            collate_fn=getattr(self.train_set, "collate_fn", None),
        )

    def val_dataloader(self) -> DataLoader:
        """Return the (unshuffled) validation loader.

        Uses the dataset's ``collate_fn`` when present, matching the training
        loader, since every split comes from the same ``dataset_class``.
        """
        return self._make_loader(
            self.val_set,
            shuffle=False,
            collate_fn=getattr(self.val_set, "collate_fn", None),
        )

    def test_dataloader(self) -> DataLoader:
        """Return the (unshuffled) test loader, honouring ``collate_fn`` if defined."""
        return self._make_loader(
            self.test_set,
            shuffle=False,
            collate_fn=getattr(self.test_set, "collate_fn", None),
        )

    def predict_dataloader(self) -> DataLoader:
        """Return the prediction loader.

        Predictions are drawn from the *validation* split. The validation
        dataset must define ``predict_collate_fn``; an ``AttributeError`` is
        raised otherwise.
        """
        return self._make_loader(
            self.val_set,
            shuffle=False,
            collate_fn=self.val_set.predict_collate_fn,
        )
