from classification.autoaugment import CIFAR10Policy
from models.ssl_models import *
import torchvision
from lightly.data.dataset import LightlyDataset
import lightly.data.collate as lightly_collate
from torch.utils.data import ConcatDataset

# logs_root_dir = os.path.join(os.getcwd(), 'benchmark_logs')

# set max_epochs to 800 for long run (takes around 10h on a single V100)


# SimCLR-style augmentation pipeline for 32x32 inputs.
# Gaussian blur is turned off for these small CIFAR images.
collate_fn = lightly_collate.SimCLRCollateFunction(
    gaussian_blur=0.0,
    input_size=32,
)

# SwAV multi-crop augmentation, reduced to two 32x32 crops per image
# (single crop size, no extra low-resolution crops). Blur is turned off as well.
swav_collate_fn = lightly_collate.SwaVCollateFunction(
    gaussian_blur=0,
    crop_min_scales=[0.14],
    crop_counts=[2],
    crop_sizes=[32],
)

# DINO augmentation using only 32x32 global crops (n_local_views=0).
# All three blur settings are zeroed to turn blur off.
dino_collate_fn = lightly_collate.DINOCollateFunction(
    gaussian_blur=(0, 0, 0),
    n_local_views=0,
    global_crop_size=32,
)


import torch
from torchvision import transforms

def get_data_loaders(batch_size: int, model=None, dataset_train_ssl=None,
                     dataset_train_kNN=None, dataset_test=None):
    """Build the SSL-train, eval-train and eval-test dataloaders.

    Args:
        batch_size: Batch size shared by all three dataloaders.
        model: Model class used to select the SSL collate function:
            ``SwaVModel`` -> ``swav_collate_fn``, ``DINOModel`` ->
            ``dino_collate_fn``, anything else (including ``None``) ->
            the SimCLR ``collate_fn``.
        dataset_train_ssl: Dataset fed to the self-supervised training loader.
        dataset_train_kNN: Training dataset fed to the evaluation loader.
        dataset_test: Test dataset fed to the evaluation loader.

    Returns:
        Tuple ``(dataloader_train_ssl, dataloader_train_eval,
        dataloader_test_eval)``.
    """
    # Pick the augmentation/collate pipeline that matches the model class.
    if model == SwaVModel:
        ssl_collate = swav_collate_fn
    elif model == DINOModel:
        ssl_collate = dino_collate_fn
    else:
        ssl_collate = collate_fn

    loader_cls = torch.utils.data.DataLoader

    def make_eval_loader(dataset):
        # Evaluation loaders keep dataset order and every sample.
        # NOTE (from the original author): when this loader is used to train
        # a classifier it should really be shuffled; it currently works only
        # because augmentation is applied.
        return loader_cls(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            # `num_workers` is a module-level name that is not defined in the
            # visible part of this file -- presumably supplied by the star
            # import from models.ssl_models; TODO confirm.
            num_workers=num_workers,
        )

    # SSL training: shuffled, incomplete last batch dropped, views produced
    # by the selected collate function.
    ssl_loader = loader_cls(
        dataset_train_ssl,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=ssl_collate,
        drop_last=True,
        num_workers=num_workers,
    )
    train_eval_loader = make_eval_loader(dataset_train_kNN)
    test_eval_loader = make_eval_loader(dataset_test)

    return ssl_loader, train_eval_loader, test_eval_loader


def load_datasets(augment=False, use_imagenet_transforms=False, data_root='./data'):
    """Create the CIFAR-100 SSL-train, eval-train and test dataloaders.

    ``augment`` only concerns the (supervised / evaluation) training set; the
    SSL training augmentations are applied by the collate function selected
    in ``get_data_loaders``.

    Intended combinations (from the original author's notes):

    * SSL train: ``augment=False, use_imagenet_transforms=True``. SSL is
      usually trained with ImageNet normalization, and the author notes that
      the SSL collate function already normalizes with ImageNet statistics.
    * Linear evaluation: ``augment=False, use_imagenet_transforms=True``,
      for the same reason.
    * Supervised classification: ``augment=True`` or ``False`` depending on
      what is trained, with ``use_imagenet_transforms=False``.

    Args:
        augment: If True, the eval-train dataset uses random crop (padding 4)
            followed by ``CIFAR10Policy`` before tensor conversion and
            normalization; otherwise it uses the same transform as the test
            set.
        use_imagenet_transforms: If True, normalize with lightly's ImageNet
            mean/std; otherwise with the CIFAR-100 statistics hard-coded
            below. The same statistics are used for the train and test
            transforms.
        data_root: Directory in which torchvision looks for / downloads
            CIFAR-100.

    Returns:
        Tuple ``(dataloader_train_ssl, dataloader_train_eval,
        dataloader_test_eval)``.
    """
    # Choose the normalization statistics once so that the train and test
    # transforms can never disagree (previously the augmented train transform
    # always used the dataset statistics, even when the test transform used
    # ImageNet statistics).
    if use_imagenet_transforms:
        norm_mean = lightly_collate.imagenet_normalize['mean']
        norm_std = lightly_collate.imagenet_normalize['std']
    else:
        norm_mean, norm_std = (0.5071, 0.4866, 0.4409), (0.2673, 0.2564, 0.276)

    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    if augment:
        cur_knn_transform = transforms.Compose([
            transforms.RandomCrop(size=32, padding=4),
            CIFAR10Policy(),
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std),
        ])
    else:
        cur_knn_transform = test_transforms
    print(f"cur_knn_transform:{cur_knn_transform}")

    # train_ssl: no transform here, the collate function does the augmentation.
    # torchvision's CIFAR100 requires `root`; omitting it raises a TypeError.
    cifar100_ssl = torchvision.datasets.CIFAR100(root=data_root, download=True)
    dataset_train_ssl = LightlyDataset.from_torch_dataset(cifar100_ssl)

    # train_kNN: deliberately a second CIFAR100 instance.
    # NOTE(review): from_torch_dataset may attach the transform to the wrapped
    # dataset, so sharing one instance with the SSL dataset could leak this
    # transform into SSL training -- confirm before merging the two loads.
    cifar100_knn = torchvision.datasets.CIFAR100(root=data_root, download=True)
    dataset_train_kNN = LightlyDataset.from_torch_dataset(
        cifar100_knn, transform=cur_knn_transform)

    # test
    cifar100_test = torchvision.datasets.CIFAR100(
        root=data_root, train=False, download=True)
    dataset_test = LightlyDataset.from_torch_dataset(
        cifar100_test, transform=test_transforms)

    # `batch_size` is a module-level name not defined in the visible part of
    # this file -- presumably supplied by the star import from
    # models.ssl_models; TODO confirm.
    # VICRegModel is neither SwaVModel nor DINOModel, so get_data_loaders
    # falls back to the SimCLR collate function for the SSL loader.
    return get_data_loaders(
        batch_size=batch_size,
        model=VICRegModel,
        dataset_train_ssl=dataset_train_ssl,
        dataset_train_kNN=dataset_train_kNN,
        dataset_test=dataset_test,
    )
