from typing import Optional

import numpy as np
import torch.utils.data as Data

from .cifar import cifar_dataset
# from .imagenet_dataset import imagenet_dataset
# from .tinyimgnet import tinyimagenet_dataset

from .semisup import SemiSupDataset, SemiSupSampler


def data_loader(data_name: str = 'cifar',
                num_classes: int = 10,
                batch_size: int = 128,
                distributed: bool = False,
                data_root: str = './data/',
                auxiliary: Optional[str] = None,
                fraction: float = 0.3,
                seed: int = 2023,
                num_workers: int = 8):
    """Build the train/test DataLoaders (and their samplers) for a dataset.

    Args:
        data_name: Dataset identifier. Any name containing ``'cifar'`` selects
            CIFAR, which is the only dataset currently enabled.
        num_classes: Number of classes; must be 10 or 100 for CIFAR.
        batch_size: Training batch size. The test loader uses twice this size.
        distributed: If True, both datasets are sampled with
            ``DistributedSampler``. Cannot be combined with ``auxiliary``.
        data_root: Directory passed to the dataset constructor.
        auxiliary: Optional auxiliary data source handed to ``SemiSupDataset``.
            When given, training batches are produced by ``SemiSupSampler``.
        fraction: Forwarded to ``SemiSupSampler`` as ``unsup_fraction``
            (presumably the share of each batch drawn from auxiliary
            samples -- confirm against ``SemiSupSampler``).
        seed: Currently unused; only referenced by the disabled ImageNet branch.
        num_workers: Worker processes for each DataLoader (default 8, matching
            the previous hard-coded value). Persistent workers are enabled
            only when this is positive.

    Returns:
        ``(train_loader, train_sampler, test_loader, test_sampler)``.
        ``train_sampler`` is a ``SemiSupSampler`` when ``auxiliary`` is set,
        a ``DistributedSampler`` when ``distributed`` is set, and ``None``
        otherwise. ``test_sampler`` is a ``DistributedSampler`` or ``None``.

    Raises:
        ValueError: If the dataset config is unsupported, or if ``auxiliary``
            is combined with ``distributed``.
    """
    # Validate before loading anything. This must not be an `assert`, which
    # is stripped under `python -O`.
    if auxiliary and distributed:
        raise ValueError('Training with auxiliary samples is only supported on a single GPU')

    if 'cifar' in data_name and num_classes in [10, 100]:
        train_dataset, test_dataset = cifar_dataset(num_classes, data_root=data_root)
    # elif 'tiny' in data_name and num_classes == 200:
    #     train_dataset, test_dataset = tinyimagenet_dataset(data_root=data_root)
    # elif data_name == 'imagenet' and num_classes <= 1000:
    #     train_dataset, test_dataset = imagenet_dataset(data_root=data_root,
    #                                                    num_classes=num_classes,
    #                                                    seed=seed)
    else:
        raise ValueError('The given dataset config is not supported!')

    # DataLoader raises if persistent_workers=True while num_workers == 0.
    persistent = num_workers > 0

    if auxiliary:
        train_dataset = SemiSupDataset(train_dataset, auxiliary, train=True)
        test_dataset = SemiSupDataset(test_dataset, auxiliary, train=False)
        num_batches = int(np.ceil(train_dataset.dataset_size / batch_size))
        train_sampler = SemiSupSampler(train_dataset.sup_indices, train_dataset.unsup_indices, batch_size,
                                       unsup_fraction=fraction, num_batches=num_batches)
        # The batch sampler yields whole batches, so batch_size/shuffle/drop_last
        # must not be passed to the DataLoader here.
        train_loader = Data.DataLoader(train_dataset,
                                       batch_sampler=train_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True,
                                       persistent_workers=persistent)
    else:
        train_sampler = Data.distributed.DistributedSampler(train_dataset) if distributed else None
        # DataLoader forbids shuffle=True together with an explicit sampler;
        # DistributedSampler does its own shuffling.
        train_loader = Data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       sampler=train_sampler,
                                       num_workers=num_workers,
                                       shuffle=(train_sampler is None),
                                       drop_last=True,
                                       pin_memory=True,
                                       persistent_workers=persistent)

    test_sampler = Data.distributed.DistributedSampler(test_dataset) if distributed else None
    test_loader = Data.DataLoader(test_dataset,
                                  batch_size=batch_size * 2,
                                  sampler=test_sampler,
                                  num_workers=num_workers,
                                  shuffle=False,
                                  drop_last=False,
                                  pin_memory=True)
    return train_loader, train_sampler, test_loader, test_sampler
