from torchvision import transforms

import torch
import random
from PIL import ImageFilter 


# Per-channel normalisation statistics shared by every pipeline in this
# module (the same values get_transform uses for its 'imagenet' branch).
mean, std = (
    (0.485, 0.456, 0.406),
    (0.229, 0.224, 0.225),
)

class GaussianBlur(object):
    """Apply a Gaussian blur whose radius is drawn at random on every call.

    Each call samples a radius uniformly from ``[sigma[0], sigma[1]]`` and
    returns ``img.filter(ImageFilter.GaussianBlur(radius=radius))``, so
    ``img`` must support PIL's ``Image.filter`` API. The pipelines in this
    module apply it before ``ToTensor``; presumably they pass PIL images
    (TODO confirm).

    Args:
        sigma: Two-element ``(min, max)`` range for the blur radius.
    """

    def __init__(self, sigma=(0.1, 2.0)):
        # Keep a per-instance list copy so that mutating ``self.sigma`` cannot
        # leak into other instances or into the default value.
        self.sigma = list(sigma)

    def __call__(self, img):
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return img.filter(ImageFilter.GaussianBlur(radius=radius))

    def __repr__(self):
        return f"{type(self).__name__}(sigma={self.sigma})"


def get_transform(transform_type='imagenet', image_size=32, args=None):
    """Return ``(train_transform, test_transform)`` for ``transform_type``.

    Only ``'imagenet'`` is implemented. Both pipelines resize the image to
    ``int(image_size / crop_pct)`` and then crop to ``image_size``:

    * train: random crop, random horizontal flip (p=0.5), ``ColorJitter()``;
    * test: center crop.

    Both finish with ``ToTensor`` and ``Normalize`` using the module-level
    ``mean``/``std``.

    Args:
        transform_type: Name of the transform family; must be ``'imagenet'``.
        image_size: Side length of the output crop.
        args: Optional namespace-like config. ``interpolation`` (default 3)
            and ``crop_pct`` (default 0.875) are read if present.

    Returns:
        Tuple ``(train_transform, test_transform)`` of ``transforms.Compose``.

    Raises:
        NotImplementedError: if ``transform_type`` is not ``'imagenet'``.
    """
    if transform_type != 'imagenet':
        raise NotImplementedError(
            f"Unsupported transform_type: {transform_type!r}")

    interpolation = getattr(args, 'interpolation', 3)
    crop_pct = getattr(args, 'crop_pct', 0.875)

    train_transform = transforms.Compose([
        transforms.Resize(int(image_size / crop_pct), interpolation),
        transforms.RandomCrop(image_size),
        transforms.RandomHorizontalFlip(p=0.5),
        # NOTE(review): ColorJitter() is built with all-default arguments,
        # which per the torchvision docs applies no jitter. Confirm whether
        # real jitter strengths were intended here.
        transforms.ColorJitter(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=torch.tensor(mean),
            std=torch.tensor(std))
    ])

    test_transform = transforms.Compose([
        transforms.Resize(int(image_size / crop_pct), interpolation),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=torch.tensor(mean),
            std=torch.tensor(std))
    ])

    return (train_transform, test_transform)

def get_strong_augmentations(image_size=32, args=None):
    """Build the strong (heavily randomised) training augmentation pipeline.

    The pipeline is selected by ``args.dataset_name``:

    * ``cifar10``, ``cifar100``, ``imagenet_100``, ``imagenet_1k``: random
      resized crop, horizontal flip, colour jitter (p=0.8), RandAugment,
      random grayscale (p=0.2) and Gaussian blur (p=0.5).
    * ``cub``, ``scars``, ``herbarium_19``, ``aircraft``: horizontal flip,
      random resized crop, RandAugment, Gaussian blur (p=0.8) and
      solarisation (p=0.5). Colour jitter and grayscale are not applied.

    Both pipelines first resize to ``int(image_size / crop_pct)`` and finish
    with ``ToTensor`` + ``Normalize`` using the module-level ``mean``/``std``.

    Args:
        image_size: Side length of the output crop.
        args: Namespace-like config. ``dataset_name`` selects the pipeline.
            Optional ``interpolation`` (default 3) and ``crop_pct``
            (default 0.875) control the initial resize.

    Returns:
        A ``transforms.Compose`` pipeline.

    Raises:
        NotImplementedError: if ``args`` is missing ``dataset_name`` or it
            names a dataset with no strong-augmentation recipe.
    """
    interpolation = getattr(args, 'interpolation', 3)
    crop_pct = getattr(args, 'crop_pct', 0.875)
    dataset_name = getattr(args, 'dataset_name', None)

    if dataset_name in ['cifar10', 'cifar100', 'imagenet_100', 'imagenet_1k']:
        strong_augmentation = transforms.Compose([
            transforms.Resize(int(image_size / crop_pct), interpolation),
            transforms.RandomResizedCrop(size=image_size, scale=(0.3, 1.0)),  # More aggressive cropping
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                                           saturation=0.4, hue=0.1)], p=0.8),
            transforms.RandAugment(num_ops=2, magnitude=9, num_magnitude_bins=20,
                                   interpolation=transforms.InterpolationMode.NEAREST, fill=None),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
            # No solarisation for this group (unlike the fine-grained group below).
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ])
    elif dataset_name in ['cub', 'scars', 'herbarium_19', 'aircraft']:
        strong_augmentation = transforms.Compose([
            transforms.Resize(int(image_size / crop_pct), interpolation),
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(size=image_size, scale=(0.3, 1.0)),  # More aggressive cropping
            # Colour jitter and random grayscale are not applied for this group.
            transforms.RandAugment(num_ops=2, magnitude=9, num_magnitude_bins=20,
                                   interpolation=transforms.InterpolationMode.NEAREST, fill=None),
            transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.8),
            transforms.RandomSolarize(threshold=128, p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ])
    else:
        # Without this branch an unknown dataset fell through to the return
        # below with ``strong_augmentation`` unbound.
        raise NotImplementedError(
            f"No strong augmentation defined for dataset_name={dataset_name!r}")

    return strong_augmentation

def get_weak_augmentation(image_size=32, args=None):
    """Build the weak view pipeline: resize, center crop, tensorise, normalise.

    The pipeline contains no random operations. ``args`` may optionally
    provide ``interpolation`` (default 3) and ``crop_pct`` (default 0.875).
    The image is resized to ``int(image_size / crop_pct)`` before the
    ``image_size`` center crop, then normalised with the module-level
    ``mean``/``std``.
    """
    interpolation = getattr(args, 'interpolation', 3)
    crop_pct = getattr(args, 'crop_pct', 0.875)
    resize_to = int(image_size / crop_pct)

    steps = [
        transforms.Resize(resize_to, interpolation),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
    ]
    return transforms.Compose(steps)

def ema_transform(image_size=32, args=None):
    """Return a ``(weak_augmentation, strong_augmentation)`` pair of pipelines.

    Both pipelines resize to ``int(image_size / crop_pct)``, take a random
    resized crop of side ``image_size`` and finish with ``ToTensor`` +
    ``Normalize`` (module-level ``mean``/``std``).

    * weak: crop scale (0.9, 1.0), horizontal flip, ``ColorJitter()``.
    * strong: crop scale (0.3, 1.0), horizontal flip (p=0.5), colour jitter
      (p=0.2), random grayscale (p=0.2), Gaussian blur (p=0.5) and
      solarisation (p=0.5).

    ``args`` may optionally provide ``interpolation`` (default 3) and
    ``crop_pct`` (default 0.875).
    """
    interpolation = getattr(args, 'interpolation', 3)
    crop_pct = getattr(args, 'crop_pct', 0.875)
    resize_to = int(image_size / crop_pct)

    def _tensor_tail():
        # Fresh ToTensor + Normalize steps for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
        ]

    weak_steps = [
        transforms.Resize(resize_to, interpolation),
        # Crop side length is ``image_size``; near-full-area crops only.
        transforms.RandomResizedCrop(size=image_size, scale=(0.9, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(),
    ] + _tensor_tail()

    jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                    saturation=0.4, hue=0.1)
    strong_steps = [
        transforms.Resize(resize_to, interpolation),
        transforms.RandomResizedCrop(size=image_size, scale=(0.3, 1.0)),  # More aggressive cropping
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([jitter], p=0.2),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomSolarize(threshold=128, p=0.5),
    ] + _tensor_tail()

    return (transforms.Compose(weak_steps), transforms.Compose(strong_steps))
