import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets.mnist as mnist
from antgine.dataset import AbstractDataset

# TODO: temporary module-level defaults; callers may pass their own transforms to MNIST().


def _to_three_channels(image: torch.Tensor) -> torch.Tensor:
    """
        Replicate a channel-first image tensor three times along dim 0.

        The input is presumably the single-channel tensor produced by ``ToTensor``
        on a grayscale image (assumed shape (1, H, W); TODO confirm). The output
        then has 3 identical channels.

        This is a module-level function rather than a lambda so the transform
        pipeline stays picklable. ``DataLoader`` pickles the dataset, including
        its transform, when ``num_workers > 0`` and workers are started with the
        ``spawn`` method. A lambda would fail there.
    """
    return torch.cat([image, image, image], dim=0)


# Default preprocessing: resize so the shorter side is 32 px, convert to a tensor,
# replicate to 3 channels, then map [0, 1] values to [-1, 1].
_default_train_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Lambda(_to_three_channels),
    # https://github.com/kuangliu/pytorch-cifar/blob/bf78d3b8b358c4be7a25f9f9438c842d837801fd/main.py#L35
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

# The test split reuses the exact same pipeline object (no augmentation is applied).
_default_test_transform = _default_train_transform


class MNIST(AbstractDataset):
    """
        MNIST digits exposed through the :class:`antgine.dataset.AbstractDataset` interface.

        Creating an instance immediately builds both the training and the test split
        (with ``download=True``) and keeps them for the ``*_loader`` methods.
    """
    def __init__(self, root: str, batch_size: int,
                 train_transform: transforms.Compose = _default_train_transform,
                 test_transform: transforms.Compose = _default_test_transform,
                 num_workers=8):
        """
        :param str root: Root directory handed to torchvision's MNIST (``download=True``).
        :param int batch_size: Number of samples per batch yielded by the loaders.
        :param transforms.Compose train_transform: Transform applied to every training sample.
        :param transforms.Compose test_transform: Transform applied to every test sample.
        :param int num_workers: Worker count passed to every ``DataLoader`` created here.
        """
        super().__init__()
        self._root = root
        self._batch_size = batch_size
        self._train_transform, self._test_transform = train_transform, test_transform
        self._num_workers = num_workers

        # Training split first, then test split; each reads its own transform
        # back through the public properties.
        splits = {}
        for is_train in (True, False):
            split_transform = self.train_transform if is_train else self.test_transform
            splits[is_train] = mnist.MNIST(self.root, train=is_train,
                                           transform=split_transform, download=True)
        self._train_set = splits[True]
        self._test_set = splits[False]

    @property
    def root(self) -> str:
        """
            Root directory supplied at construction.
            Implements :class:`antgine.dataset.AbstractDataset`.
        """
        return self._root

    @property
    def batch_size(self) -> int:
        """
            Batch size used by both loaders.
            Implements :class:`antgine.dataset.AbstractDataset`.
        """
        return self._batch_size

    @property
    def train_transform(self) -> transforms.Compose:
        """
            Transform attached to the training split.
            Implements :class:`antgine.dataset.AbstractDataset`.
        """
        return self._train_transform

    @property
    def test_transform(self) -> transforms.Compose:
        """
            Transform attached to the test split.
            Implements :class:`antgine.dataset.AbstractDataset`.
        """
        return self._test_transform

    @property
    def train_set(self) -> data.Dataset:
        """
            Training split built in ``__init__``.
            Implements :class:`antgine.dataset.AbstractDataset`.
        """
        return self._train_set

    @property
    def test_set(self) -> data.Dataset:
        """
            Test split built in ``__init__``.
            Implements :class:`antgine.dataset.AbstractDataset`.
        """
        return self._test_set

    def _loader(self, dataset: data.Dataset, shuffle: bool):
        # Shared DataLoader factory: batch size and worker count come from __init__.
        loader_options = dict(batch_size=self.batch_size,
                              shuffle=shuffle,
                              num_workers=self._num_workers)
        return data.DataLoader(dataset, **loader_options)

    def train_loader(self, shuffle=True) -> data.DataLoader:
        """
            Build a ``DataLoader`` over the training split (shuffled by default).
            Implements :class:`antgine.dataset.AbstractDataset`.
        """
        training_split = self.train_set
        return self._loader(training_split, shuffle)

    def test_loader(self, shuffle=False) -> data.DataLoader:
        """
            Build a ``DataLoader`` over the test split (unshuffled by default).
            Implements :class:`antgine.dataset.AbstractDataset`.
        """
        testing_split = self.test_set
        return self._loader(testing_split, shuffle)
