from typing import Any, Dict, Optional, Tuple

import torch
from lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from .pub_pred_dataset import *
from torchvision.transforms import transforms
import os
from .dataloaders import DataLoaders


def _current_username() -> str:
    """Return the name of the user running this process.

    ``os.getlogin()`` asks the controlling terminal and raises ``OSError`` when
    there is none (e.g. cron jobs, daemons, many containers and batch
    schedulers), which would otherwise make this module fail at import time.
    In that case fall back to ``getpass.getuser()``, which consults the
    LOGNAME/USER/LNAME/USERNAME environment variables and then the password
    database.
    """
    try:
        return os.getlogin()
    except OSError:
        import getpass  # only needed on the fallback path

        return getpass.getuser()


username = _current_username()

# Per-user root of the public forecasting datasets. Each D2PATH value is a
# directory (with trailing slash) that PubPredDataModule passes as
# ``root_path`` alongside ``data_path=f'{dset}.csv'``.
_PUBLIC_DATA_ROOT = f'/data/{username}/public_data/'
# All four ETT variants share one directory.
_ETT_SMALL_DIR = f'{_PUBLIC_DATA_ROOT}ETDataset/ETT-small/'

D2PATH = {
    'ETTm1': _ETT_SMALL_DIR,
    'ETTm2': _ETT_SMALL_DIR,
    'ETTh1': _ETT_SMALL_DIR,
    'ETTh2': _ETT_SMALL_DIR,
    'electricity': f'{_PUBLIC_DATA_ROOT}electricity/',
    'traffic': f'{_PUBLIC_DATA_ROOT}traffic/',
    'national_illness': f'{_PUBLIC_DATA_ROOT}illness/',
    'weather': f'{_PUBLIC_DATA_ROOT}weather/',
    'exchange_rate': f'{_PUBLIC_DATA_ROOT}exchange_rate/',
}

# Dataset class used to load each benchmark. The keys must stay in sync with
# D2PATH: PubPredDataModule looks a dataset up in both tables.
# ETTm* keys use Dataset_ETT_minute, ETTh* keys use Dataset_ETT_hour, and
# every remaining benchmark uses the generic Dataset_Custom loader.
D2SET = {
    **dict.fromkeys(('ETTm1', 'ETTm2'), Dataset_ETT_minute),
    **dict.fromkeys(('ETTh1', 'ETTh2'), Dataset_ETT_hour),
    **dict.fromkeys(
        ('electricity', 'traffic', 'national_illness', 'weather', 'exchange_rate'),
        Dataset_Custom,
    ),
}


class PubPredDataModule(LightningDataModule):
    """Lightning data module for the public time-series forecasting benchmarks.

    The benchmark is selected by ``params.dset``, which must be a key of both
    ``D2PATH`` (data directory) and ``D2SET`` (dataset class). All loaders are
    built eagerly in ``__init__`` through ``DataLoaders``; the ``*_dataloader``
    hooks simply delegate to that object.
    """

    def __init__(
        self,
        params
    ) -> None:
        """Build the train/validation/test loaders for ``params.dset``.

        :param params: Configuration object. The attributes read here are
            ``dset``, ``context_points``, ``target_points``, ``features``,
            ``use_time_features``, ``batch_size`` and ``num_workers``.
        :raises KeyError: If ``params.dset`` is not a supported dataset name.
        """
        print(params)
        super().__init__()

        self.params = params
        # Makes the init argument (``params``) available as ``self.hparams``
        # and stores it in checkpoints; ``logger=False`` keeps it out of the
        # experiment logger.
        self.save_hyperparameters(logger=False)

        # Kept from the Lightning template for attribute compatibility; this
        # module never populates them because loading is delegated to
        # ``DataLoaders``.
        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        self.dls: DataLoaders = self._build_dataloaders()

    def _build_dataloaders(self) -> DataLoaders:
        """Create the ``DataLoaders`` for ``self.params.dset`` and annotate it
        with ``vars``, ``len`` and ``c`` taken from the configuration and the
        first training sample.

        :raises KeyError: If ``self.params.dset`` is missing from ``D2PATH`` or
            ``D2SET``.
        """
        dset = self.params.dset
        if dset not in D2PATH or dset not in D2SET:
            # Fail with the list of valid choices instead of a bare KeyError.
            raise KeyError(
                f"Unknown dataset {dset!r}; supported datasets are {sorted(D2PATH)}"
            )

        # Presumably [seq_len, label_len, pred_len] as expected by the dataset
        # classes; label_len is fixed to 0 here. TODO confirm in pub_pred_dataset.
        size = [self.params.context_points, 0, self.params.target_points]
        dls = DataLoaders(
            datasetCls=D2SET[dset],
            dataset_kwargs={
                'root_path': D2PATH[dset],
                'data_path': f'{dset}.csv',
                'features': self.params.features,
                'scale': True,
                'size': size,
                'use_time_features': self.params.use_time_features,
            },
            batch_size=self.params.batch_size,
            workers=self.params.num_workers,
        )

        # Fetch the first training sample once and derive the dimensions from it.
        sample = dls.train.dataset[0]
        # vars: dim 1 of the sample's first element (presumably the number of
        # channels for a (seq_len, n_vars) input; TODO confirm).
        dls.vars = sample[0].shape[1]
        dls.len = self.params.context_points
        # c: dim 0 of the sample's second element (presumably the target
        # horizon; TODO confirm).
        dls.c = sample[1].shape[0]
        return dls

    def prepare_data(self) -> None:
        """No-op: the CSV files are expected to already exist under the
        directories listed in ``D2PATH``; nothing is downloaded."""
        ...

    def setup(self, stage: Optional[str] = None) -> None:
        """No-op Lightning hook.

        All datasets and loaders are already constructed in ``__init__``, so
        there is nothing to load or split per stage.

        :param stage: The stage being set up (``"fit"``, ``"validate"``,
            ``"test"`` or ``"predict"``). Ignored.
        """
        ...

    def train_dataloader(self) -> DataLoader[Any]:
        """Return the training loader provided by ``self.dls``."""
        return self.dls.train_dataloader()

    def val_dataloader(self) -> DataLoader[Any]:
        """Return the validation loader provided by ``self.dls``."""
        return self.dls.val_dataloader()

    def test_dataloader(self) -> DataLoader[Any]:
        """Return the test loader provided by ``self.dls``."""
        return self.dls.test_dataloader()

    def teardown(self, stage: Optional[str] = None) -> None:
        """No-op Lightning cleanup hook.

        :param stage: The stage being torn down (``"fit"``, ``"validate"``,
            ``"test"`` or ``"predict"``). Ignored.
        """
        pass

    def state_dict(self) -> Dict[Any, Any]:
        """Return the checkpointed datamodule state; this module has none.

        :return: An empty dict.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore datamodule state from a checkpoint; nothing to restore.

        :param state_dict: The value previously returned by ``state_dict()``.
            Ignored.
        """
        pass