import numpy as np
import torch
import torch.utils.data as Data
import torchvision.transforms as transforms
from PIL import Image


class simple_dataset(Data.Dataset):
    """Map-style dataset over paired, already-loaded samples ``X`` and labels ``Y``.

    Samples held in a ``torch.Tensor`` are indexed and used as they are.
    Any other container (the loaders in this file pass NumPy arrays) has
    each indexed sample wrapped with ``Image.fromarray`` before the
    optional transform runs.

    Args:
        X: Indexable collection of samples exposing ``.shape``; its first
            dimension is the dataset length.
        Y: Indexable collection of labels, looked up with the same index.
        transform: Optional callable applied to each sample; ``None``
            returns the sample untouched.
    """

    def __init__(self, X: torch.Tensor, Y: torch.Tensor, transform=None):
        self.X = X
        self.Y = Y
        self.transform = transform
        # True means "convert each sample to a PIL image on access";
        # that is the case for everything that is not a torch.Tensor.
        self.flag = not isinstance(X, torch.Tensor)

    def __getitem__(self, index: int):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        sample = self.X[index]
        if self.flag:
            sample = Image.fromarray(sample)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.Y[index]

    def __len__(self):
        """Return the number of samples (size of the first dimension of ``X``)."""
        return self.X.shape[0]


def tinyimagenet_dataset(data_root='./data'):
    """Build the train and validation datasets from ``tiny200.npz``.

    The archive ``<data_root>/tiny200.npz`` must contain the arrays
    ``trainX``, ``trainY``, ``valX`` and ``valY``. Images stay as NumPy
    arrays (so ``simple_dataset`` converts each sample to a PIL image on
    access); labels are converted to tensors with ``torch.from_numpy``.

    Args:
        data_root: Directory that contains ``tiny200.npz``.

    Returns:
        A ``(trainset, testset)`` tuple of ``simple_dataset`` objects. The
        training set applies random crop / RandAugment / horizontal flip /
        ``ToTensor`` / random erasing; the test set applies ``ToTensor`` only.
    """
    # np.load on an .npz returns a lazy NpzFile that keeps the archive open.
    # Reading inside the context manager materialises the arrays and then
    # closes the underlying file handle instead of leaking it.
    with np.load(f'{data_root}/tiny200.npz') as data:
        trainX = data['trainX']
        trainY = torch.from_numpy(data['trainY'])
        valX = data['valX']
        valY = torch.from_numpy(data['valY'])

    # Order matters: the PIL-based augmentations run before ToTensor, and
    # RandomErasing runs after it because it operates on tensors.
    transform_train = transforms.Compose([
        transforms.RandomCrop(64, padding=8),
        transforms.RandAugment(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.RandomErasing()
    ])

    # Evaluation uses no augmentation, only the conversion to a tensor.
    transform_test = transforms.ToTensor()

    trainset = simple_dataset(trainX, trainY, transform_train)
    testset = simple_dataset(valX, valY, transform_test)
    return trainset, testset


def DDPM_dataset(data_root='./data', num_classes=100):
    """Build the training dataset from ``c<num_classes>_250k.npz``.

    The archive ``<data_root>/c<num_classes>_250k.npz`` must contain the
    arrays ``image`` and ``label``. Images are converted to a float tensor,
    their last axis is moved to position 1, and values are divided by 255.

    Args:
        data_root: Directory that contains the ``.npz`` archive.
        num_classes: Number used to build the archive file name.

    Returns:
        A ``simple_dataset`` that applies random crop and random horizontal
        flip directly on the tensor samples.
    """
    # np.load on an .npz returns a lazy NpzFile that keeps the archive open.
    # Reading inside the context manager materialises the arrays and then
    # closes the underlying file handle instead of leaking it.
    with np.load(f'{data_root}/c{num_classes}_250k.npz') as data:
        # 4-D array with axis 3 moved to position 1
        # (presumably NHWC -> NCHW; confirm against the archive layout).
        trainX = torch.from_numpy(data['image']).permute(0, 3, 1, 2)
        trainY = torch.from_numpy(data['label'])

    # .float() allocates a new float tensor, so the in-place division does not
    # touch the loaded array. Presumably rescales 8-bit pixels to [0, 1].
    trainX = trainX.float().div_(255.)

    # NOTE(review): RandomCrop(64, padding=4) needs each padded image to be at
    # least 64 pixels per side -- confirm the stored image size.
    transform_train = transforms.Compose([
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip()
    ])

    trainset = simple_dataset(trainX, trainY, transform_train)
    return trainset
