import torch
import os
import lpips
import numpy as np
import torchvision
import torch.optim as optim
import math
import torch.nn as nn
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
def image_grid(imgs, rows, cols):
    """Tile ``rows * cols`` equally sized images into one RGB image.

    Images are placed in row-major order; the cell size is taken from the
    first image's ``size`` attribute, read as ``(width, height)``.

    NOTE(review): ``Image`` is not imported anywhere in the visible part of
    this file -- presumably ``PIL.Image``; confirm it is in scope.

    Args:
        imgs: Sequence of images; its length must equal ``rows * cols``.
        rows: Number of rows in the grid.
        cols: Number of columns in the grid.

    Returns:
        The assembled grid image.
    """
    assert len(imgs) == rows*cols

    cell_w, cell_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols*cell_w, rows*cell_h))

    for position, tile in enumerate(imgs):
        # Row-major placement: quotient selects the row, remainder the column.
        row, col = divmod(position, cols)
        canvas.paste(tile, box=(col*cell_w, row*cell_h))
    return canvas

def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.

    A 3-dimensional input is treated as a single image and given a leading
    batch axis. Values are multiplied by 255, rounded and cast to ``uint8``
    (so inputs are presumably in [0, 1] -- confirm against callers).

    NOTE(review): ``Image`` is not imported anywhere in the visible part of
    this file -- presumably ``PIL.Image``; confirm it is in scope.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    as_bytes = (batch * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in as_bytes]


def unflatten_space(feature_map, tensor_shape):  # unsquash spatial dims
    """Return an independent copy of ``feature_map`` reshaped to ``tensor_shape``."""
    restored = torch.reshape(feature_map, tensor_shape)
    return restored.clone()

def flatten_space(feature_map):  # squash spatial dims
    """Return an independent copy with the last two dims merged into one.

    For an ``n x c x h x w`` input the result is ``n x c x (h*w)``.
    """
    merged = feature_map.flatten(start_dim=-2)
    return merged.clone()

def demean(feature_map, dim=-1):
    """Subtract the mean taken along ``dim`` from ``feature_map``.

    Returns:
        A ``(centered, mean)`` pair; ``mean`` keeps the reduced dimension
        (size 1) so it broadcasts against the input.
    """
    mean = feature_map.mean(dim=dim, keepdim=True)
    return feature_map - mean, mean

def instance_whiten(batch_feature_map):
    """ZCA-whiten the channels of every sample in a batch independently.

    The last two dims are flattened, each channel is centered over the
    flattened positions, and a per-sample channel covariance (normalised by
    ``N - 1``) is computed. The whitening matrix is ``U diag(S^-1/2) U^T``
    built from the SVD of that covariance.

    Args:
        batch_feature_map: Tensor whose flattened form is 3-D
            (batch, channel, positions), i.e. a 4-D input.

    Returns:
        A tensor with the same shape as the input holding the whitened
        features.
    """
    original_shape = batch_feature_map.shape
    centered, _ = demean(flatten_space(batch_feature_map))
    num_positions = centered.shape[-1]

    # Per-sample channel covariance.
    cov = torch.einsum('bcx, bdx -> bcd', centered, centered) / (num_positions-1)  # compute covs along batch dim

    # NOTE(review): torch.svd is deprecated in favour of torch.linalg.svd.
    basis, spectrum, _ = torch.svd(cov)
    # Zero singular values yield inf here; no regularisation is applied.
    inv_sqrt_spectrum = torch.diag_embed(spectrum**(-.5))
    whitener = torch.einsum('nab, nbc, ncd -> nad', basis, inv_sqrt_spectrum, basis.transpose(-2,-1))

    whitened = torch.einsum('bac, bcx -> bax', whitener, centered)
    return unflatten_space(whitened, original_shape)






def projection_regret(loader, model, scheduler, timestep, cur_time, device, t2,n1,n2):
    """Accumulate a per-sample LPIPS difference score over every batch.

    For each batch, ``n1`` targets are produced by adding Gaussian noise
    scaled by ``timestep[cur_time]`` to the clean images and passing the
    result through ``scheduler.step`` (presumably a one-step
    denoiser/projection -- confirm against the scheduler). For every target,
    ``n2`` noise draws scaled by ``timestep[t2]`` are added both to the
    target and to the clean images (the same draw for both), each is passed
    through ``scheduler.step`` again, and the two LPIPS distances

        LPIPS(recon(clean), clean)   and   LPIPS(recon(target), target)

    are summed separately over all ``n1 * n2`` draws. The score is the first
    sum minus the second (a sum, not an average).

    Args:
        loader: Iterable of ``(images, targets)`` pairs; targets are ignored.
        model: Network handed to ``scheduler.step``; put into eval mode.
        scheduler: Object exposing ``step(model, noisy_images, sigma)``.
        timestep: Table of noise scales indexed with integer tensors.
        cur_time: Index into ``timestep`` for the first noising stage.
        device: Device the images and the LPIPS network are moved to.
        t2: Index into ``timestep`` for the second noising stage.
        n1: Number of targets drawn per batch (must be >= 1).
        n2: Number of noise draws per target (must be >= 1).

    Returns:
        A numpy array with the per-batch scores concatenated along axis 0,
        or an empty list when ``loader`` yields no batches.

    Raises:
        ValueError: If ``n1`` or ``n2`` is smaller than 1.
    """
    if n1 < 1 or n2 < 1:
        # With zero draws the accumulators would stay plain ints and the
        # tensor conversion below would fail with an obscure AttributeError.
        raise ValueError("n1 and n2 must both be >= 1, got n1={!r}, n2={!r}".format(n1, n2))

    LPIPS = lpips.LPIPS(net="vgg").to(device)
    model.eval()

    batch_scores = []
    for clean_images, _ in loader:
        clean_images = clean_images.to(device)
        batch_size = clean_images.shape[0]

        # One table index per sample, shaped to broadcast over C, H, W.
        ct_ = (cur_time) * torch.ones([batch_size,1,1,1]).long().to(device)
        ct = timestep[ct_]
        ct2_ = t2 * torch.ones([batch_size,1,1,1]).long().to(device)
        ct2 = timestep[ct2_]
        sigma_first = ct.view(-1,1,1,1)
        sigma_second = ct2.view(-1,1,1,1)

        id_LPIPS = 0
        ood_LPIPS = 0
        with torch.no_grad():
            for _ in range(n1):
                target = scheduler.step(model, clean_images + torch.randn_like(clean_images) * sigma_first, ct)

                for _ in range(n2):
                    # The same noise draw perturbs both the target and the input.
                    noise = torch.randn_like(clean_images)
                    id_recon = scheduler.step(model, target + noise * sigma_second, ct2)
                    ood_recon = scheduler.step(model, clean_images + noise * sigma_second, ct2)
                    ood_LPIPS += LPIPS(ood_recon, clean_images).mean([1,2,3])
                    id_LPIPS += LPIPS(id_recon, target).mean([1,2,3])

        batch_scores.append((ood_LPIPS - id_LPIPS).detach().cpu().numpy())

    if not batch_scores:
        # Preserve the historical return value for an empty loader.
        return []
    # Single concatenation instead of re-copying the running array per batch.
    return np.concatenate(batch_scores, 0)


def recon_ablation(loader,model,scheduler,timesteps,cur_time,device,n1,mode,mode2):
    """Sum a per-sample reconstruction error over ``n1`` noisy reconstructions.

    Each batch is noised with ``scheduler.add_noise`` at level
    ``timesteps[cur_time]``, reconstructed with ``scheduler.step_dm``
    (``mode2 == 'edm'``) or ``scheduler.step`` (``mode2 == 'cm'``), and
    compared with the clean images. The error is summed (not averaged) over
    the ``n1`` repetitions.

    Args:
        loader: Iterable of ``(images, targets)`` pairs; targets are ignored.
        model: Network handed to the scheduler; put into eval mode.
        scheduler: Object exposing ``add_noise`` and ``step`` / ``step_dm``.
        timesteps: Table of noise levels indexed with integer tensors.
        cur_time: Index into ``timesteps``.
        device: Device the images (and LPIPS network, if used) are moved to.
        n1: Number of noisy reconstructions per batch (must be >= 1).
        mode: ``'l2'`` for the squared error summed over dims 1-3, or
            ``'lpips'`` for the LPIPS distance averaged over dims 1-3.
        mode2: ``'edm'`` or ``'cm'``; selects the scheduler method.

    Returns:
        A numpy array with the per-batch errors concatenated along axis 0,
        or an empty list when ``loader`` yields no batches.

    Raises:
        ValueError: If ``mode``/``mode2`` is not recognised or ``n1 < 1``.
    """
    # Fail early: an unknown option used to surface as an unbound local or
    # as an AttributeError on the integer accumulator.
    if mode not in ('l2', 'lpips'):
        raise ValueError("mode must be 'l2' or 'lpips', got {!r}".format(mode))
    if mode2 not in ('edm', 'cm'):
        raise ValueError("mode2 must be 'edm' or 'cm', got {!r}".format(mode2))
    if n1 < 1:
        raise ValueError("n1 must be >= 1, got {!r}".format(n1))

    # Only build the VGG-based metric when it is actually needed.
    LPIPS = lpips.LPIPS(net="vgg").to(device) if mode == 'lpips' else None
    reconstruct = scheduler.step_dm if mode2 == 'edm' else scheduler.step
    # Match the other evaluation routines in this file, which all switch the
    # model to eval mode before scoring.
    model.eval()

    batch_errors = []
    for clean_images, _ in loader:
        clean_images = clean_images.to(device)
        batch_size = clean_images.shape[0]

        # One table index per sample, shaped to broadcast over C, H, W.
        t_ = (cur_time * torch.ones([batch_size,1,1,1])).long().to(device)
        t = timesteps[t_]
        recon_error = 0

        with torch.no_grad():
            for _ in range(n1):
                noisy_images = scheduler.add_noise(clean_images, torch.randn_like(clean_images), t)
                target_outputs = reconstruct(model, noisy_images, t).detach()

                if mode == 'l2':
                    recon_error += torch.sum((target_outputs - clean_images) ** 2, dim=[1,2,3])
                else:
                    recon_error += LPIPS(target_outputs, clean_images).mean([1,2,3])

        batch_errors.append(recon_error.detach().cpu().numpy())

    if not batch_errors:
        # Preserve the historical return value for an empty loader.
        return []
    # Single concatenation instead of re-copying the running array per batch.
    return np.concatenate(batch_errors, axis=0)



def _checkerboard_masks(device):
    """Build the two complementary 4x4-patch checkerboard masks, each [1,3,32,32]."""
    mask_even = torch.zeros([1,3,32,32]).to(device)
    mask_odd = torch.zeros([1,3,32,32]).to(device)
    for row in range(8):
        for col in range(8):
            # Patches whose (row + col) is even go to one mask, odd to the other.
            chosen = mask_even if (row + col) % 2 == 0 else mask_odd
            chosen[:, :, 4*row:4*row+4, 4*col:4*col+4] = 1
    return mask_even, mask_odd


def impaint_lpips(loader, model, scheduler, timesteps, device):
    """Score every sample by the median LPIPS of ten masked inpainting runs.

    Two complementary checkerboard masks made of 4x4 patches on a 3x32x32
    grid are used alternately (even runs use one, odd runs the other). Each
    run calls ``impaint_mask`` and yields one distance per sample; the score
    is the median over the ten runs.

    Args:
        loader: Iterable of ``(images, targets)`` pairs; targets are ignored.
            The masks are hard-coded to 3x32x32.
        model: Network handed to ``impaint_mask``; put into eval mode.
        scheduler: Scheduler object handed to ``impaint_mask``.
        timesteps: Noise-level table handed to ``impaint_mask``.
        device: Device the images, masks and LPIPS network are moved to.

    Returns:
        A numpy array with the per-batch medians concatenated along axis 0,
        or an empty list when ``loader`` yields no batches.
    """
    LPIPS = lpips.LPIPS(net="vgg").to(device)
    model.eval()

    # The masks do not depend on the data, so build them once.
    mask_even, mask_odd = _checkerboard_masks(device)

    batch_medians = []
    for clean_images, _ in loader:
        clean_images = clean_images.to(device)
        batch_size = clean_images.shape[0]
        # Zero-copy views with an explicit batch dimension.
        masks = (
            mask_even.expand(batch_size, -1, -1, -1),
            mask_odd.expand(batch_size, -1, -1, -1),
        )

        run_columns = []
        for j_ in range(10):
            test_image = clean_images.clone()
            distance_ = impaint_mask(test_image, model, scheduler, masks[j_ % 2], timesteps, device,LPIPS)
            run_columns.append(np.reshape(distance_.detach().cpu().numpy(), [-1,1]))

        # Columns are runs; reduce them to one median per sample.
        batch_medians.append(np.median(np.concatenate(run_columns, 1), 1))

    if not batch_medians:
        # Previously an empty loader raised UnboundLocalError here.
        return []
    # Single concatenation instead of re-copying the running array per batch.
    return np.concatenate(batch_medians, 0)

def impaint_mask(test_image,model,scheduler,mask,timesteps,device,LPIPS):
    """Inpaint the masked-out region in 16 steps and return the LPIPS distance.

    Starting from ``mask * test_image``, each step adds noise with
    ``scheduler.add_noise`` at level ``timesteps[k]`` (``k`` runs from 17
    down to 2), reconstructs with ``scheduler.step``, and then pastes the
    known pixels back: ``mask * test_image + (1 - mask) * reconstruction``.

    Args:
        test_image: Batch of reference images.
        model: Network handed to ``scheduler.step``.
        scheduler: Object exposing ``add_noise`` and ``step``.
        mask: Tensor broadcastable against ``test_image``; 1 marks pixels
            that are kept, 0 marks pixels that are regenerated.
        timesteps: Table of noise levels; indices 2..17 are read.
        device: Device for the index tensors.
        LPIPS: Callable distance; its output is averaged over dims 1-3.

    Returns:
        ``LPIPS(test_image, inpainted).mean([1, 2, 3])``.
    """
    with torch.no_grad():
        batch_size = test_image.shape[0]
        known_pixels = torch.mul(mask, test_image)
        hole = 1-mask

        current = known_pixels
        for iter_ in range(16):
            cur_time = 17-iter_
            # One table index per sample, shaped to broadcast over C, H, W.
            t_ = (cur_time * torch.ones([batch_size,1,1,1])).long().to(device)
            t = timesteps[t_]

            # NOTE(review): for every step after the first the original code
            # also computed sqrt(t**2 - 0.002**2) but never used it; noise is
            # added at level t. If the reduced level was intended, it should
            # be passed to add_noise instead of t -- confirm.
            noisy_image = scheduler.add_noise(current, torch.randn_like(current), t)
            target_outputs = scheduler.step(model, noisy_image, t).detach()

            # Keep the known pixels, take the reconstruction elsewhere.
            current = known_pixels + torch.mul(hole, target_outputs)

        return LPIPS(test_image, current).mean([1,2,3])




def projection_regret_whole_score(loader, model, scheduler, timestep, cur_time, device, t2,n1,n2,r):
    """Return every individual LPIPS difference instead of their sum.

    For each batch and each of ``n1`` targets, ``n2`` noise draws scaled by
    ``timestep[t2]`` are added both to the target and to the clean images
    (the same draw for both); each is passed through ``scheduler.step``
    (presumably a one-step denoiser/projection -- confirm against the
    scheduler) and the difference

        LPIPS(recon(clean), clean) - LPIPS(recon(target), target)

    is stored per sample, per target and per draw.

    Args:
        loader: Iterable of ``(images, targets)`` pairs; targets are ignored.
        model: Network handed to ``scheduler.step``; put into eval mode.
        scheduler: Object exposing ``step(model, noisy_images, sigma)``.
        timestep: Table of noise scales indexed with integer tensors.
        cur_time: Index into ``timestep`` for the first noising stage.
        device: Device the images and the LPIPS network are moved to.
        t2: Index into ``timestep`` for the second noising stage.
        n1: Number of targets per batch (must be >= 1).
        n2: Number of noise draws per target (must be >= 1).
        r: Number of times the first-stage target is recomputed. Every
            repetition starts again from the clean images, so only the last
            one is kept; with ``r == 0`` the target is a copy of the clean
            images.

    Returns:
        A numpy array of shape ``(num_samples, n1, n2)``, or an empty list
        when ``loader`` yields no batches.

    Raises:
        ValueError: If ``n1`` or ``n2`` is smaller than 1.
    """
    if n1 < 1 or n2 < 1:
        # Zero draws used to end in an UnboundLocalError further down.
        raise ValueError("n1 and n2 must both be >= 1, got n1={!r}, n2={!r}".format(n1, n2))

    LPIPS = lpips.LPIPS(net="vgg").to(device)
    model.eval()

    batch_blocks = []
    for clean_images, _ in loader:
        clean_images = clean_images.to(device)
        batch_size = clean_images.shape[0]

        # One table index per sample, shaped to broadcast over C, H, W.
        ct_ = (cur_time) * torch.ones([batch_size,1,1,1]).long().to(device)
        ct = timestep[ct_]
        ct2_ = t2 * torch.ones([batch_size,1,1,1]).long().to(device)
        ct2 = timestep[ct2_]
        sigma_first = ct.view(-1,1,1,1)
        sigma_second = ct2.view(-1,1,1,1)

        per_target = []
        with torch.no_grad():
            for _ in range(n1):
                target = clean_images.clone()

                # NOTE(review): each pass restarts from clean_images rather
                # than from the previous target, so r > 1 only repeats the
                # work (and advances the RNG). If a chained projection was
                # intended, the noise should be added to `target` -- confirm.
                for _ in range(r):
                    target = scheduler.step(model, clean_images + torch.randn_like(clean_images) * sigma_first, ct)

                per_draw = []
                for _ in range(n2):
                    # The same noise draw perturbs both the target and the input.
                    noise = torch.randn_like(clean_images)
                    id_recon = scheduler.step(model, target + noise * sigma_second, ct2)
                    ood_recon = scheduler.step(model, clean_images + noise * sigma_second, ct2)
                    ood_LPIPS = LPIPS(ood_recon, clean_images).mean([1,2,3])
                    id_LPIPS = LPIPS(id_recon, target).mean([1,2,3])

                    ood_spec = np.reshape(ood_LPIPS.detach().cpu().numpy(), [-1,1,1])
                    id_spec = np.reshape(id_LPIPS.detach().cpu().numpy(), [-1,1,1])
                    per_draw.append(ood_spec - id_spec)

                # Axis 2 indexes the n2 draws for this target.
                per_target.append(np.concatenate(per_draw, 2))

        # Axis 1 indexes the n1 targets for this batch.
        batch_blocks.append(np.concatenate(per_target, 1))

    if not batch_blocks:
        # Previously an empty loader raised UnboundLocalError here.
        return []
    # Axis 0 indexes samples across all batches.
    return np.concatenate(batch_blocks, 0)

