import sys
import os.path as osp
import time
from PIL import Image
import random
import numpy as np

from sklearn.metrics import multilabel_confusion_matrix, classification_report
from sklearn.metrics import precision_score, recall_score, f1_score

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchmetrics.functional as MF
from timm.data.auto_augment import auto_augment_transform, rand_augment_transform

import datasets
from datasets.transforms import ResizeImage
from utils.metric import accuracy, ConfusionMatrix
from utils.meter import AverageMeter, ProgressMeter
from datasets.imagelist import MultipleDomainsDataset

import clip

# Module-wide default torch device: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def seed_fix(n):
    """Seed the Python, NumPy and PyTorch random number generators with ``n``.

    When cuDNN is enabled, its benchmark mode is additionally switched off and
    its deterministic flag is switched on.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(n)

    cudnn = torch.backends.cudnn
    if not cudnn.enabled:
        return
    cudnn.benchmark = False
    cudnn.deterministic = True


def freeze_model(model):
    """Freeze ``model`` in place and return it.

    Every parameter stops requiring gradients, every sub-module whose class
    name contains "BatchNorm" gets its momentum set to 0.0, and the model is
    put into eval mode.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)

    batchnorm_layers = (
        layer for layer in model.modules()
        if "BatchNorm" in layer.__class__.__name__
    )
    for layer in batchnorm_layers:
        layer.momentum = 0.0

    model.eval()
    return model


def convert_models_to_fp32(model):
    """Cast every parameter of ``model`` (and its gradient, if any) to float32.

    The conversion is done in place on ``p.data`` / ``p.grad.data``; nothing
    is returned.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Compare against None explicitly: the truth value of a tensor with
        # more than one element is ambiguous and raises, and a single-element
        # gradient equal to zero would wrongly be skipped.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def model_load(model, ckpt, prompt_key):
    """Load only the checkpoint entries selected by ``prompt_key`` into ``model``.

    Args:
        model: module whose ``state_dict`` is partially overwritten.
        ckpt: path (or file-like object) passed to ``torch.load``; it is
            expected to hold a mapping from parameter name to value.
        prompt_key: iterable of substrings; a checkpoint entry is kept when
            its name contains at least one of them.

    Returns:
        The same ``model`` object, with the selected entries loaded.

    Raises:
        AssertionError: if no checkpoint entry matches ``prompt_key``.
    """
    # Deserialize onto the CPU so that a checkpoint written on a GPU machine
    # can also be opened on a CPU-only host; load_state_dict below copies the
    # values into the model's own tensors, wherever those live.
    pretrained_dict = torch.load(ckpt, map_location="cpu")

    prompt_dict = {
        name: value
        for name, value in pretrained_dict.items()
        if any(key in name for key in prompt_key)
    }

    # Raised explicitly (rather than via ``assert``) so the check still runs
    # under ``python -O``; the exception type is unchanged for callers.
    if not prompt_dict:
        raise AssertionError(
            f"no checkpoint entry matches any of the keys {list(prompt_key)!r}"
        )
    print(prompt_dict.keys())

    # Start from the model's current weights so every non-selected entry
    # keeps its value and load_state_dict sees a complete state dict.
    model_dict = model.state_dict()
    model_dict.update(prompt_dict)
    model.load_state_dict(model_dict)
    return model


def make_fewshot_dataset(dataset, num_shot):
    """Shrink ``dataset.samples`` to ``num_shot`` random samples per class.

    The class id of a sample is its last element. ``dataset`` is modified in
    place and also returned. An AssertionError is raised when any class has
    fewer than ``num_shot`` samples.
    """
    # Bucket the samples by class id, keeping first-seen class order.
    by_class = {}
    for record in dataset.samples:
        by_class.setdefault(record[-1], []).append(record)

    # Every class must be able to provide the requested number of shots.
    assert all(len(group) >= num_shot for group in by_class.values())

    # Draw the shots class by class and flatten them into one list.
    selected = []
    for group in by_class.values():
        selected.extend(random.sample(group, num_shot))
    dataset.samples = selected

    return dataset


def get_model_names():
    """Return the supported model identifiers as a sorted list."""
    names = ["CLIPRN50", "CLIPRN101", "CLIPViT-B/16", "CLIPViT-B/32"]
    names.sort()
    return names


def get_model(model_name, pretrain=True):
    """Build the backbone identified by ``model_name``.

    Args:
        model_name: a name starting with "CLIP" (the remainder is handed to
            ``clip.load``), a name found in an optional module-level
            ``models`` registry, or any name accepted by
            ``timm.create_model``.
        pretrain: forwarded as ``pretrained`` to the registry / timm
            constructors. It is not used for CLIP models.

    Returns:
        For CLIP names: the CLIP model converted to float32, frozen and in
        eval mode. Otherwise: the constructed backbone; timm backbones get an
        ``out_features`` attribute and have their classifier removed.
    """
    clip_prefix = "CLIP"
    # The prefix is sliced off below, so only names that actually start with
    # it are CLIP names.
    if model_name.startswith(clip_prefix):
        clip_model, _ = clip.load(model_name[len(clip_prefix):], device=device, jit=False)
        convert_models_to_fp32(clip_model)
        # freeze_model already switches to eval mode and returns the model.
        return freeze_model(clip_model)

    # ``models`` is not imported by this file; look it up defensively so a
    # missing registry falls through to timm instead of raising NameError.
    model_zoo = globals().get("models")
    if model_zoo is not None and model_name in model_zoo.__dict__:
        return model_zoo.__dict__[model_name](pretrained=pretrain)

    # load models from pytorch-image-models
    backbone = timm.create_model(model_name, pretrained=pretrain)
    try:
        backbone.out_features = backbone.get_classifier().in_features
        backbone.reset_classifier(0, '')
    except Exception:
        # Best-effort fallback for backbones without the classifier API;
        # unlike a bare ``except`` this lets KeyboardInterrupt/SystemExit
        # propagate.
        backbone.out_features = backbone.head.in_features
        backbone.head = nn.Identity()
    return backbone


def get_dataset_names():
    """Return the sorted names of all callable, non-dunder members of ``datasets``."""
    names = [
        name
        for name, member in datasets.__dict__.items()
        if not name.startswith("__") and callable(member)
    ]
    names.sort()
    return names


def get_dataset(dataset_name, root, source, target, train_source_transform, val_transform, train_target_transform=None, val_target_transform=None, fewshot=0):
    """Build the multi-domain training and validation datasets.

    Args:
        dataset_name: name of a dataset class exposed by the ``datasets`` module.
        root: dataset root directory, forwarded to the dataset constructor.
        source: tasks (domains) used for training; they receive domain ids
            ``0 .. len(source) - 1``.
        target: tasks (domains) used for validation; their domain ids
            continue from ``len(source)``.
        train_source_transform / val_transform: ``transform`` for the
            training / validation datasets.
        train_target_transform / val_target_transform: ``target_transform``
            for the training / validation datasets.
        fewshot: when non-zero, every training domain is reduced to this many
            samples per class. The validation datasets are never reduced.

    Returns:
        ``(train_source_dataset, val_dataset, num_classes, class_names)``,
        where the class names come from the first training domain.

    Raises:
        NotImplementedError: if ``dataset_name`` is not found in ``datasets``.
    """
    if dataset_name not in datasets.__dict__:
        raise NotImplementedError(dataset_name)

    # load datasets from tllib.vision.datasets
    dataset_cls = datasets.__dict__[dataset_name]

    def concat_dataset(tasks, start_idx, fewshot=0, **kwargs):
        # One dataset per task, optionally sub-sampled right after creation.
        per_domain = []
        for task in tasks:
            domain_dataset = dataset_cls(task=task, **kwargs)
            if fewshot:
                domain_dataset = make_fewshot_dataset(domain_dataset, fewshot)
            per_domain.append(domain_dataset)
        domain_ids = list(range(start_idx, start_idx + len(tasks)))
        return MultipleDomainsDataset(per_domain, tasks, domain_ids=domain_ids)

    train_source_dataset = concat_dataset(
        tasks=source,
        start_idx=0,
        fewshot=fewshot,
        root=root,
        download=True,
        transform=train_source_transform,
        target_transform=train_target_transform,
    )
    val_dataset = concat_dataset(
        tasks=target,
        start_idx=len(source),
        fewshot=0,
        root=root,
        download=True,
        transform=val_transform,
        target_transform=val_target_transform,
    )

    class_names = train_source_dataset.datasets[0].classes
    return train_source_dataset, val_dataset, len(class_names), class_names


def refine_classname(class_names):
    """Lower-case every class name and turn '_' and '-' into spaces.

    The list is rewritten in place and the same list object is returned.
    """
    separators_to_space = str.maketrans('_-', '  ')
    for position, raw_name in enumerate(class_names):
        class_names[position] = raw_name.lower().translate(separators_to_space)
    return class_names


def validate(val_loader, model, args, device) -> float:
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 accuracy.

    The first two items of every batch are used as images and targets. Loss
    is the cross entropy between the model output and the targets. Progress
    is printed every ``args.print_freq`` batches; when ``args.per_class_eval``
    is set, a confusion matrix over ``args.class_names`` is accumulated and
    printed at the end.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1], prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    confmat = ConfusionMatrix(len(args.class_names)) if args.per_class_eval else None

    with torch.no_grad():
        end = time.time()
        for step, batch in enumerate(val_loader):
            images, target = (tensor.to(device) for tensor in batch[:2])
            batch_size = images.size(0)

            # forward pass and loss
            logits = model(images)
            loss = F.cross_entropy(logits, target)

            # bookkeeping: accuracy, optional confusion matrix, loss
            (acc1,) = accuracy(logits, target, topk=(1,))
            if confmat:
                confmat.update(target, logits.argmax(1))
            losses.update(loss.item(), batch_size)
            top1.update(acc1.item(), batch_size)

            # time spent on this batch
            batch_time.update(time.time() - end)
            end = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
        if confmat:
            print(confmat.format(args.class_names))

    return top1.avg


def validate_clip(val_loader, model, args, device, mode=None) -> float:
    """Evaluate a contrastive model on paired batches and return top-1 accuracy.

    Each batch provides two image tensors as its first two items (named
    ``images_s`` and ``images_t`` here). ``model(images)`` must return a
    2-tuple.

    Args:
        val_loader: iterable of batches.
        model: model to evaluate; ``model.logit_scale`` is read in mode "v".
        args: namespace providing ``print_freq``.
        device: device the images and labels are placed on.
        mode: one of
            - "s": feed ``images_s``; the first model output is scored
              against the labels ``0 .. batch_size - 1``;
            - "t": same as "s" but with ``images_t``;
            - "v": feed both image sets concatenated, split the second model
              output into two halves and score their scaled similarity
              matrix (and its transpose) against the diagonal.

    Returns:
        The average top-1 accuracy over all batches.

    Raises:
        ValueError: if ``mode`` is not "s", "t" or "v" (this includes the
            default ``None``).
    """
    # Fail early with a clear message: an unsupported mode (including the
    # default None) used to surface as an UnboundLocalError on ``images``
    # inside the loop.
    if mode not in ("s", "t", "v"):
        raise ValueError(f"mode must be one of 's', 't' or 'v', got {mode!r}")

    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, data in enumerate(val_loader):
            images_s, images_t = data[:2]

            if mode == "v":
                # One forward pass over both sets, then split the features.
                images = torch.cat((images_s, images_t), dim=0).to(device)
                _, f = model(images)
                f_s, f_t = f.chunk(2, dim=0)

                logit_scale = model.logit_scale.exp()
                logits_per_image_s = logit_scale * f_s @ f_t.t()
                logits_per_image_t = logits_per_image_s.t()

                # The i-th item of one half is matched with the i-th item of
                # the other half, so the targets are the diagonal indices.
                ground_truth = torch.arange(len(f_s), dtype=torch.long, device=device)
                loss = (F.cross_entropy(logits_per_image_s, ground_truth)
                        + F.cross_entropy(logits_per_image_t, ground_truth)) / 2

                # Score both directions together for the accuracy.
                ground_truth = torch.cat((ground_truth, ground_truth), dim=0)
                output = torch.cat((logits_per_image_s, logits_per_image_t), dim=0)
            else:
                images = (images_s if mode == "s" else images_t).to(device)
                output, _ = model(images)
                ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
                loss = F.cross_entropy(output, ground_truth)

            # measure accuracy and record loss
            acc1, = accuracy(output, ground_truth, topk=(1,))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))

    return top1.avg


def get_train_transform(resizing='default', scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), random_horizontal_flip=True,
                        random_color_jitter=False, resize_size=224, norm_mean=(0.485, 0.456, 0.406),
                        norm_std=(0.229, 0.224, 0.225), auto_augment=None):
    """
    Build the training-time image transform.

    resizing mode:
        - default: ResizeImage(256), then a random resized crop of size 224
          (using ``scale`` and ``ratio``);
        - cen.crop: ResizeImage(256), then the center crop of size 224;
        - ran.crop: ResizeImage(256), then a random crop of size 224;
        - res.: ResizeImage(resize_size) only;
      any other value raises NotImplementedError.

    The resizing stage is followed by an optional random horizontal flip,
    then either an auto-augment policy (``auto_augment``; names starting with
    'rand' use RandAugment) or, if none is given, an optional color jitter,
    and finally ToTensor and Normalize(norm_mean, norm_std).
    """
    if resizing == 'res.':
        geometry = ResizeImage(resize_size)
        output_size = resize_size
    else:
        if resizing == 'default':
            crop = T.RandomResizedCrop(224, scale=scale, ratio=ratio)
        elif resizing == 'cen.crop':
            crop = T.CenterCrop(224)
        elif resizing == 'ran.crop':
            crop = T.RandomCrop(224)
        else:
            raise NotImplementedError(resizing)
        geometry = T.Compose([ResizeImage(256), crop])
        output_size = 224

    pipeline = [geometry]
    if random_horizontal_flip:
        pipeline.append(T.RandomHorizontalFlip())

    if auto_augment:
        # The mean color is expressed on the 0-255 scale for the policy.
        mean_color = tuple(min(255, round(255 * channel)) for channel in norm_mean)
        aa_params = dict(
            translate_const=int(output_size * 0.45),
            img_mean=mean_color,
            interpolation=Image.BILINEAR
        )
        if auto_augment.startswith('rand'):
            make_policy = rand_augment_transform
        else:
            make_policy = auto_augment_transform
        pipeline.append(make_policy(auto_augment, aa_params))
    elif random_color_jitter:
        pipeline.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))

    pipeline += [
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std)
    ]
    return T.Compose(pipeline)


def get_val_transform(resizing='default', resize_size=224,
                      norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
    """Build the validation-time image transform.

    Only ``resizing='default'`` is supported: the image goes through
    ``ResizeImage((resize_size, resize_size))``; any other value raises
    NotImplementedError. No ToTensor / Normalize step is applied here, so
    ``norm_mean`` and ``norm_std`` are currently unused.
    """
    if resizing != 'default':
        raise NotImplementedError(resizing)

    resize_stage = T.Compose([
        ResizeImage((resize_size, resize_size)),
    ])
    # Kept as a nested Compose so the returned object has the same structure.
    return T.Compose([resize_stage])

def get_negative_transform(resize_size=224):
    """Compose the geometric augmentations: random resized crop to
    ``resize_size``, random horizontal flip, random affine (5 degrees) and
    random rotation (5 degrees, no canvas expansion).

    No ToTensor / Normalize step is included.
    """
    # NOTE(review): the upper crop-scale bound (1.15) is above 1.0 - confirm
    # this is intended.
    return T.Compose([
        T.RandomResizedCrop(resize_size, scale=(0.85, 1.15), ratio=(3. / 4., 4. / 3.)),
        T.RandomHorizontalFlip(),
        T.RandomAffine(5),
        T.RandomRotation(5, expand=False),
    ])

def get_positive_transform(resize_size=224, norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
    """Compose the photometric augmentation pipeline.

    Steps, in order: ``ResizeImage((resize_size, resize_size))``, a color
    jitter that only changes the hue (range -0.5 .. 0.5), ToTensor, and
    Normalize(norm_mean, norm_std).

    Args:
        resize_size: side length passed to ResizeImage.
        norm_mean: per-channel mean for Normalize.
        norm_std: per-channel standard deviation for Normalize.

    Returns:
        The composed transform.
    """
    # Only the hue jitter is part of the pipeline. The other augmentations
    # that used to be instantiated here (grayscale, invert, solarize,
    # sharpness, blur, autocontrast) were never applied, so they are no
    # longer constructed.
    return T.Compose(
        [
            ResizeImage((resize_size, resize_size)),
            T.ColorJitter(hue=(-0.5, 0.5)),
            T.ToTensor(),
            T.Normalize(mean=norm_mean, std=norm_std)
        ]
    )