import gc
import glob
import multiprocessing
import os
import pickle
import time
from typing import List, Dict

import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from dreamsim import dreamsim
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms as T
import clip
import seaborn as sn
from skimage.metrics import structural_similarity
from scipy.stats import wasserstein_distance,pearsonr,binom,linregress
from skimage import exposure
from skimage.filters import sobel

import warnings
warnings.filterwarnings('ignore')

from sklearnex import patch_sklearn, config_context
patch_sklearn()

# CUDA device index used by compute_dreamsim and compute_clipsim.
# Can be overridden without editing the file via the METRICS_GPU_ID
# environment variable; defaults to 2.
gpu_id = int(os.environ.get('METRICS_GPU_ID', 2))

def compute_dreamsim(paths):
    """Compute the DreamSim perceptual distance for every image pair.

    Args:
        paths: sequence of pairs whose first two items are image paths
            (original, generated).

    Returns:
        list with one DreamSim distance per pair, in input order
        (lower means more similar).
    """
    device = f"cuda:{gpu_id}"

    def compute_sim(model, preprocess, path1, path2):
        # convert('RGB') makes RGBA / grayscale / palette PNGs yield three
        # channels; the `with` closes the underlying file handles.
        with Image.open(path1) as im1, Image.open(path2) as im2:
            img1 = preprocess(im1.convert('RGB').resize((224, 224))).to(device)
            img2 = preprocess(im2.convert('RGB').resize((224, 224))).to(device)
        # Per the original author's note the model takes RGB images in [0, 1]
        # of size batch_size x 3 x 224 x 224; preprocess is assumed to add the
        # batch dimension, since its output is passed straight to the model.
        dist = model(img1, img2)
        return dist.cpu().numpy()[0]

    model, preprocess = dreamsim(pretrained=True, device=device)
    dists = []
    # Inference only: no_grad stops autograd from recording a graph per pair.
    with torch.no_grad():
        for pair in tqdm(paths, total=len(paths), leave=False):
            dists.append(compute_sim(model, preprocess, pair[0], pair[1]))
    # Drop the model, collect garbage, then return cached GPU blocks.
    del model, preprocess
    gc.collect()
    torch.cuda.empty_cache()
    return dists

def compute_clipsim(paths):
    """Compute CLIP ViT-L/14 image-embedding cosine similarity for each pair.

    Args:
        paths: sequence of pairs whose first two items are image paths
            (original, generated).

    Returns:
        list with one cosine similarity (Python float) per pair, in input
        order. This is a similarity: higher means more similar.
    """
    device = f"cuda:{gpu_id}"
    # CLIP image normalization statistics; built once instead of per pair.
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    def load(path):
        # convert('RGB') guarantees three channels to match the 3-channel
        # normalization; the `with` closes the file handle.
        with Image.open(path) as im:
            img = T.functional.resize(im.convert('RGB'), (224, 224))
        tensor = normalize(T.functional.to_tensor(img).float())
        return tensor.unsqueeze(0).to(device)  # add batch dimension

    def compute_sim(net, path1, path2):
        img1, img2 = load(path1), load(path2)
        with torch.no_grad():
            output1 = net(img1)
            output2 = net(img2)
        return F.cosine_similarity(output1, output2, dim=1).item()

    model, _ = clip.load("ViT-L/14", device=device)
    # Run only the visual tower, in float32, on the selected GPU.
    net = model.visual.to(device=device, dtype=torch.float32)

    dists = []
    for pair in tqdm(paths, total=len(paths), leave=False):
        dists.append(compute_sim(net, pair[0], pair[1]))
    # Drop the model, collect garbage, then return cached GPU blocks.
    del model, net
    gc.collect()
    torch.cuda.empty_cache()
    return dists

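# Convention: most metrics below are distances (lower means more similar).
# Exceptions, where higher means more similar: pix_corr, psnr_diff and
# cosine_similarity_diff (compute_clipsim above is also a similarity).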
def load_image(path, mode='RGB'):
    """Open *path*, convert it to PIL *mode* and resize it to 224x224.

    Returns the pixels as a NumPy array (uint8 for the 'RGB' and 'L' modes
    used in this file). Callers must cast before subtracting to avoid
    integer wrap-around.
    """
    with Image.open(path) as source:
        resized = source.convert(mode).resize((224, 224))
    return np.array(resized)
    
def pix_corr(paths):
    """Pearson correlation of the flattened RGB pixels (higher = more similar)."""
    flat_a = load_image(paths[0]).ravel()
    flat_b = load_image(paths[1]).ravel()
    correlation = np.corrcoef(flat_a, flat_b)
    return correlation[0, 1]

def avg_color_diff(paths):
    """Absolute difference of the mean RGB intensity of the two images."""
    means = [load_image(p).mean() for p in (paths[0], paths[1])]
    return np.abs(means[0] - means[1])

def contrast_diff(paths):
    """Absolute difference of grayscale standard deviation (contrast)."""
    spreads = [load_image(p, 'L').std() for p in (paths[0], paths[1])]
    return np.abs(spreads[0] - spreads[1])

def brightness_diff(paths):
    """Absolute difference of mean grayscale intensity (brightness)."""
    levels = [load_image(p, 'L').mean() for p in (paths[0], paths[1])]
    return np.abs(levels[0] - levels[1])

def frequency_diff(paths):
    """Mean absolute difference of the 2-D FFT magnitude spectra of two
    grayscale images (lower = more similar)."""
    gray_a = load_image(paths[0], 'L')
    gray_b = load_image(paths[1], 'L')
    # Resize both to the smaller common (width, height) so the spectra align.
    width = min(gray_a.shape[1], gray_b.shape[1])
    height = min(gray_a.shape[0], gray_b.shape[0])
    gray_a = np.array(Image.fromarray(gray_a).resize((width, height)))
    gray_b = np.array(Image.fromarray(gray_b).resize((width, height)))
    spectrum_a = np.abs(np.fft.fft2(gray_a))
    spectrum_b = np.abs(np.fft.fft2(gray_b))
    return np.mean(np.abs(spectrum_a - spectrum_b))
    
def ssim(paths):
    """Structural dissimilarity, 1 - SSIM, between two RGB images.

    Images are scaled to [0, 1]; SSIM uses Gaussian weighting (sigma 1.5)
    and population covariance. Lower means more similar.
    """
    reference = load_image(paths[0]) / 255.0
    candidate = load_image(paths[1]) / 255.0
    score = structural_similarity(
        reference,
        candidate,
        channel_axis=2,
        gaussian_weights=True,
        sigma=1.5,
        use_sample_covariance=False,
        data_range=1.0,
    )
    return 1 - score

def psnr_diff(paths):
    """Peak signal-to-noise ratio (dB) between two RGB images.

    Unlike most metrics here, higher means more similar; identical images
    return ``inf``.
    """
    img1, img2 = load_image(paths[0]), load_image(paths[1])
    # load_image returns uint8; subtract in float64 so differences do not
    # wrap around modulo 256.
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return np.inf
    max_pixel = 255.0
    return 20 * np.log10(max_pixel / np.sqrt(mse))

def mse_diff(paths):
    """Mean squared pixel error between two RGB images (lower = more similar)."""
    img1, img2 = load_image(paths[0]), load_image(paths[1])
    # Cast before subtracting: uint8 arithmetic would wrap around.
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    return np.mean(diff ** 2)

def nrmse_diff(paths):
    """Root-mean-squared error normalized by the value range of the first image.

    Lower means more similar. A constant first image gives a zero range,
    which yields inf/nan under NumPy float division.
    """
    # Cast before subtracting: uint8 arithmetic would wrap around.
    img1 = load_image(paths[0]).astype(np.float64)
    img2 = load_image(paths[1]).astype(np.float64)
    rmse = np.sqrt(np.mean((img1 - img2) ** 2))
    return rmse / (img1.max() - img1.min())

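# Histogram intersection distance on the intensity histogram pooled over all
# RGB channels. load_image gives both images the same size (224x224 RGB), so
# both histograms have the same total count and the result lies in [0, 1].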
def histogram_intersection_diff(paths):
    """1 - normalized histogram intersection of pooled RGB intensities."""
    first, second = load_image(paths[0]), load_image(paths[1])
    counts_a = np.histogram(first.ravel(), bins=256, range=(0, 256))[0]
    counts_b = np.histogram(second.ravel(), bins=256, range=(0, 256))[0]
    overlap = np.minimum(counts_a, counts_b).sum()
    return 1 - overlap / counts_a.sum()

def emd_diff(paths):
    """Earth mover's (1-Wasserstein) distance between the pooled RGB
    intensity distributions of two images (lower = more similar).

    The histograms are weights over the intensity values 0..255. Passing the
    raw counts as the *values* (as before) measured the distance between
    distributions of bin counts, which is not the intended quantity.
    """
    img1, img2 = load_image(paths[0]), load_image(paths[1])
    hist1, _ = np.histogram(img1.ravel(), bins=256, range=(0, 256))
    hist2, _ = np.histogram(img2.ravel(), bins=256, range=(0, 256))
    intensities = np.arange(256)
    return wasserstein_distance(intensities, intensities,
                                u_weights=hist1, v_weights=hist2)

def ncc_diff(paths):
    """1 - normalized cross-correlation of the z-scored RGB pixels.

    Lower means more similar; constant images (zero std) give nan.
    """
    a = load_image(paths[0])
    b = load_image(paths[1])
    za = (a - a.mean()) / a.std()
    zb = (b - b.mean()) / b.std()
    return 1 - (za * zb).mean()

def nae_diff(paths):
    """Normalized absolute error: sum|img1 - img2| / sum(img1).

    Lower means more similar. Computed in float64 because the uint8 arrays
    from load_image would wrap around on subtraction.
    """
    img1 = load_image(paths[0]).astype(np.float64)
    img2 = load_image(paths[1]).astype(np.float64)
    return np.sum(np.abs(img1 - img2)) / np.sum(img1)

def chi_squared_diff(paths):
    """Symmetric chi-squared distance between pooled RGB intensity histograms."""
    hist_a = np.histogram(load_image(paths[0]).ravel(), bins=256, range=(0, 256))[0]
    hist_b = np.histogram(load_image(paths[1]).ravel(), bins=256, range=(0, 256))[0]
    # The small epsilon keeps empty bins (0 + 0) from dividing by zero.
    denominator = hist_a + hist_b + 1e-10
    return 0.5 * np.sum((hist_a - hist_b) ** 2 / denominator)

def cosine_similarity_diff(paths):
    """Cosine similarity of the flattened RGB pixel vectors.

    Despite the name this is a similarity: higher means more similar.
    The vectors are cast to float64 because np.dot on uint8 arrays
    accumulates in uint8 and silently overflows.
    """
    img1, img2 = load_image(paths[0]), load_image(paths[1])
    img1_flat = img1.ravel().astype(np.float64)
    img2_flat = img2.ravel().astype(np.float64)
    return np.dot(img1_flat, img2_flat) / (np.linalg.norm(img1_flat) * np.linalg.norm(img2_flat))

def histogram_matching_diff(paths):
    """Mean absolute difference between img1 and img2 after matching img2's
    histogram to img1 (per channel for color images). Lower = more similar.
    """
    img1, img2 = load_image(paths[0]), load_image(paths[1])

    if img1.ndim == 3 and img2.ndim == 3:  # Color images
        # Float buffer: a uint8 buffer (zeros_like(img2)) truncated the
        # floating-point output of match_histograms.
        matched_img2 = np.empty(img2.shape, dtype=np.float64)
        for channel in range(img1.shape[2]):
            matched_img2[:, :, channel] = exposure.match_histograms(img2[:, :, channel], img1[:, :, channel])
    else:  # Grayscale images
        matched_img2 = np.asarray(exposure.match_histograms(img2, img1), dtype=np.float64)

    # Cast img1 as well so the subtraction cannot wrap around in uint8.
    return np.mean(np.abs(img1.astype(np.float64) - matched_img2))

def edge_energy_diff(paths):
    """Mean absolute difference of Sobel edge magnitudes on grayscale images."""
    grad_a = sobel(load_image(paths[0], 'L'))
    grad_b = sobel(load_image(paths[1], 'L'))
    return np.abs(grad_a - grad_b).mean()

def compute_metric(metric_func, image_pairs, processes=None):
    """Apply *metric_func* to every pair in parallel worker processes.

    Args:
        metric_func: picklable (module-level) callable taking one pair.
        image_pairs: iterable of (path1, path2) pairs.
        processes: number of worker processes; defaults to
            multiprocessing.cpu_count().

    Returns:
        list of results in the same order as image_pairs.
    """
    if processes is None:
        processes = multiprocessing.cpu_count()
    with multiprocessing.Pool(processes=processes) as pool:
        # Pool.map already returns a list in input order.
        return pool.map(metric_func, image_pairs)

def compute_metric_single_thread(metric_func, image_pairs):
    """Call *metric_func* once, in-process, with the whole list of pairs.

    Used for metrics that iterate over all pairs themselves.
    """
    return metric_func(image_pairs)

def get_dists(orig_images, generated_images):
    """Compute every enabled metric for the paired image lists.

    Args:
        orig_images: list of original image paths.
        generated_images: list of generated image paths; element i is
            compared with orig_images[i]. zip() is used, so extra items in
            the longer list are ignored.

    Returns:
        dict mapping metric name -> list of per-pair values, in pair order.
    """
    # These metrics load a GPU model and loop over all pairs themselves, so
    # they receive the whole pair list in-process rather than via a Pool.
    in_process = ('dreamsim', 'clipsim')
    metrics = {
        'clipsim': compute_clipsim,
        'dreamsim': compute_dreamsim,
        'pix_corr': pix_corr,
        'avgColor': avg_color_diff,
        'contrastDiff': contrast_diff,
        'brightnessDiff': brightness_diff,
        #'frequencyDiff': frequency_diff,
        'ssim': ssim,
        'psnr': psnr_diff,
        'mse': mse_diff,
        #'nrmse': nrmse_diff,
        #'histogramIntersection': histogram_intersection_diff,
        #'emd': emd_diff,
        'ncc': ncc_diff,
        #'nae': nae_diff,
        #'chiSquared': chi_squared_diff,
        'cosineSimilarity': cosine_similarity_diff,
        #'histogramMatching': histogram_matching_diff,
        'edgeEnergy': edge_energy_diff,
    }

    image_pairs = list(zip(orig_images, generated_images))
    results = {}
    progress = tqdm(metrics.items(), leave=False)
    for name, func in progress:
        progress.set_description(f"Computing {name}")
        runner = compute_metric_single_thread if name in in_process else compute_metric
        results[name] = runner(func, image_pairs)
    return results

def _png_paths(folder):
    """Return the ``*.png`` files directly inside *folder* (unsorted)."""
    return glob.glob(folder + '/*.png')


def _file_name(path):
    """Return the last '/'-separated component of *path*."""
    return path.split('/')[-1]


def _id_before_underscore(path):
    """Sort key: the integer before the first '_' of the file name."""
    return int(_file_name(path).split('_')[0])


def _id_before_dot(path):
    """Sort key: the integer before the first '.' of the file name."""
    return int(_file_name(path).split('.')[0])


def _takagi_id(path):
    """Sort key: integer before the first '_' of the part of the name before the first '.'."""
    return int(_file_name(path).split('.')[0].split('_')[0])


def _takagi_sample(path):
    """Integer after the last '_' of the file name, up to its first '.' (presumably a sample index)."""
    return int(_file_name(path).split('_')[-1].split('.')[0])


def _braindiffuser_dir(subj, bottleneck):
    """Folder holding Brain-Diffuser reconstructions for *subj* at *bottleneck*."""
    if bottleneck == 15000:  # 15000 is the original brain diffusers image reconstruction
        return f'/storage/user1/projects/image_metrics/{subj}_2/brain_diffuser'
    return f'/storage/user1/BrainBitsWIP/results/versatile_diffusion/{subj}/train_single_{bottleneck}'


def _takagi_dir(subj, bottleneck):
    """Folder holding Takagi-method reconstructions for *subj* at *bottleneck*."""
    if bottleneck == 15000:
        return f'/storage/user1/projects/image_metrics/image-text/{subj}/samples'
    return f'/storage/user1/projects/image_metrics/bottleneck_training/{subj}/BottleneckLinear/bottleneck_dim={bottleneck}/decoded_test_imgs'


def _score_runs(original_images, runs, description):
    """Run get_dists for each (key, generated_paths) in *runs*; return {key: metrics}."""
    results = {}
    pbar = tqdm(runs, total=len(runs))
    for key, generated in pbar:
        pbar.set_description(description.format(key))
        results[key] = get_dists(original_images, generated)
    return results


def main():
    """Compute all metrics for every reconstruction set.

    Returns:
        dict with
          'braindiffuser_bits':   {subj: {bottleneck: metrics}}
          'braindiffuser_random': {bottleneck: metrics}
          'text_baseline':        {model_name: metrics}
          'takagi_bits':          {subj: {bottleneck: metrics}}
        where each ``metrics`` is the dict returned by get_dists.

    Raises:
        ValueError: if a text-baseline model name has no known folder.
    """
    subjects = ['subj01', 'subj02', 'subj05', 'subj07']
    all_results = {}

    # original NSD images sorted by id
    original_image_path = '/storage/user1/StableDiffusionReconstruction-brainbits/decoded/originals'
    original_images = sorted(_png_paths(original_image_path), key=_id_before_underscore)

    # braindiffuser bits bottleneck. The per-method dict is created once,
    # before the subject loop, so the results of every subject are kept.
    bottlenecks = [1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 500, 1000, 2000, 5000, 8000, 15000]
    all_results['braindiffuser_bits'] = {}
    for subj in subjects:
        runs = [(b, sorted(_png_paths(_braindiffuser_dir(subj, b)), key=_id_before_dot))
                for b in bottlenecks]
        all_results['braindiffuser_bits'][subj] = _score_runs(
            original_images, runs, 'BrainDiffuser bottleneck {}')

    # random baseline (subj01 only)
    bottlenecks = [1, 2, 3, 4, 5, 50, 55, 500, 1000]
    runs = []
    for b in bottlenecks:
        folder = f'/storage/user1/BrainBitsWIP/results/versatile_diffusion/subj01/train_single_random_brain_{b}'
        runs.append((b, sorted(_png_paths(folder), key=_id_before_dot)))
    all_results['braindiffuser_random'] = _score_runs(original_images, runs, 'Random bottleneck {}')

    # text baseline
    model_names = ["vd_text_1", "vd_text_2", "vd_text_3", "vd_text_4", "vd_text_5", "xl_text_5"]
    folder_names = {"vd_text_1": "1words_1715401774",
                    "vd_text_2": "2words_1715376540",
                    "vd_text_3": "3words_1715134542",
                    "vd_text_4": "4words_1715386819",
                    "vd_text_5": "5words_1715039198",
                    }

    def text_dir(name):
        if name[:7] == "vd_text":
            return f'/storage/user1/projects/Versatile-Diffusion-old/{folder_names[name]}'
        if name == "xl_text_5":
            return '/storage/user1/projects/stable_diffusion_xl/xl_gen'
        # Raise a real exception type; raising a plain str is a TypeError.
        raise ValueError(f'Invalid model name: {name}')

    runs = [(name, sorted(_png_paths(text_dir(name)), key=_id_before_underscore))
            for name in model_names]
    all_results['text_baseline'] = _score_runs(original_images, runs, 'Text baseline {}')

    # takagi method: keep only files whose trailing '_' token is 0.
    # Like braindiffuser_bits, the per-method dict is created once.
    bottlenecks = [1, 5, 10, 50, 100, 500, 1000, 15000]
    all_results['takagi_bits'] = {}
    for subj in subjects:
        runs = []
        for b in bottlenecks:
            first_samples = [p for p in _png_paths(_takagi_dir(subj, b)) if _takagi_sample(p) == 0]
            runs.append((b, sorted(first_samples, key=_takagi_id)))
        all_results['takagi_bits'][subj] = _score_runs(
            original_images, runs, 'Takagi bottleneck {}')

    return all_results

if __name__ == '__main__':
    results = main()
    # The timestamp is taken after main() returns, i.e. at completion time.
    output_path = f'all_metric_results_{int(time.time())}.pickle'
    with open(output_path, 'wb') as handle:
        pickle.dump(results, handle)