import sys
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from deel.lip.model import Model as LipModel
from deel.utils.yaml_to_params import load_yaml_config,getParams, getFunctionFromModules, dumdict2yaml
from deel.datasets.load_dataset import load_dataset
from deel.utils.yaml_loader import load_model, loadFunctionList
from deel.utils.yaml_loader import load_optimizer_and_loss
from deel.lip.layers import SpectralDense,FrobeniusDense,SpectralConv2D
from deel.lip.extra_layers import BatchCentering
from deel.lip.activations import MaxMin, GroupSort, GroupSort2, FullSort
from tensorflow.keras.layers import Flatten ,Dense,ReLU,Softmax,Lambda
from deel.lip.model import Model as LipModel
from tensorflow.keras import Model as KerasModel
from deel.lip.normalizers import set_stop_grad_spectral,set_grad_passthrough_bjorck
from deel.utils.lip_res_model import set_add_coeff
import deel.lip.extra_layers
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
                                  SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,
                                  GradCAMPP, Lime, KernelShap)
import time
from scipy.stats import spearmanr
import random
import cv2
def spearman_dist(phi1, phi2):
    """Return the Spearman rank distance ``1 - |rho|`` between two maps.

    ``rho`` is Spearman's rank correlation between the flattened maps, so
    identical (or perfectly reversed) rankings give 0.

    Args:
        phi1, phi2: explanation maps holding the same number of values. When
            ``phi1`` is 3-D, both maps are first averaged over their last axis.

    Returns:
        A scalar distance in ``[0, 1]``.
    """
    if len(phi1.shape) == 3:
        phi1 = np.mean(phi1, -1)
        phi2 = np.mean(phi2, -1)
    # spearmanr returns a (statistic, pvalue) pair. Only the statistic is the
    # correlation; applying np.abs to the whole result produced a 2-element
    # array that mixed in the p-value.
    rho = spearmanr(np.array(phi1).flatten(), np.array(phi2).flatten())[0]
    return 1.0 - np.abs(rho)



def get_random_name():
    """Build a random experiment name of the form ``<adjective>_<noun>``.

    The adjective is drawn before the noun with ``random.choice``, so the
    result is reproducible after ``random.seed``.
    """
    adjectives = (
        "horrific", "nasty", "gloomy", "hardcore", "metalic", "creepy",
        "insane", "sinister", "neocore", "cyberpunk", "notorious", "mephitic",
        "purple", "angry", "forbidden", "heavy", "burning", "gory", "lofi",
        "ancient", "transcendent", "bloody", "fallen", "guilty", "outrageous",
        "ghostly", "darkly", "obscure", "cold", "dramatic", "mad", "crazy",
        "weird", "ultra", "resilient",
    )
    nouns = (
        "experiment", "minotaure", "rats", "scientist", "network", "witch",
        "head", "master", "guru", "cypher", "l0rd", "count", "byte",
        "overflow", "overdrive", "banshee", "cockroach", "worms", "gremlins",
        "goons", "yokai", "hacker", "abyss", "worm", "troll", "machine",
        "computer", "bot", "robot", "blob", "misfit", "blade", "mamba", "moon",
        "virus", "raptor", "quark", "freak", "AI", "crew", "thug", "mage",
        "mecha", "vengeance", "lain", "tashikoma", "horror", "boogeyman",
    )
    first = random.choice(adjectives)
    second = random.choice(nouns)
    return f"{first}_{second}"


class Logger(object):
    """Tee stream: everything written goes to the terminal and to a log file.

    Presumably meant to replace ``sys.stdout`` (it implements ``write`` and
    ``flush``). The terminal stream is whatever ``sys.stdout`` was when the
    logger was created.
    """

    def __init__(self, filename):
        # Capture the current stdout so writes still reach the terminal after
        # sys.stdout is swapped for this object.
        self.terminal = sys.stdout
        self.log = open(filename, "w")

    def write(self, message):
        """Write ``message`` to the terminal stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams.

        This used to be a no-op, so ``print(..., flush=True)`` never reached
        the log file and its tail could be lost on a crash.
        """
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file. The terminal stream is left open."""
        if not self.log.closed:
            self.log.close()


def gradient_norm(grads, verbose=False):
    """Return the mean L2 norm over a list of gradient tensors.

    ``None`` entries (variables without a gradient) add 0 to the sum but still
    count in the denominator.

    Args:
        grads: iterable of gradient tensors and/or ``None``, or ``None``.
        verbose: unused; kept for call-site compatibility.

    Returns:
        0 when ``grads`` is ``None`` or empty, otherwise the mean norm.
    """
    if grads is None:
        return 0
    total = 0
    count = 0
    for g in grads:
        if g is not None:
            total += tf.norm(g)
        count += 1
    if count == 0:
        # An empty gradient list used to raise ZeroDivisionError.
        return 0
    return total / count

def fine_tune_model(model, nb_class, size_dense=256, vanilla_export=False):
    """Freeze ``model`` and attach a new Lipschitz classification head.

    Every existing layer is made non-trainable. The new head reads the output
    of the second-to-last layer, so the original last layer is bypassed. The
    head is ``SpectralDense(size_dense) -> GroupSort -> FrobeniusDense(nb_class)``.

    Args:
        model: Keras-style model exposing ``layers`` and ``inputs``.
        nb_class: number of output units of the new head.
        size_dense: width of the hidden ``SpectralDense`` layer.
        vanilla_export: currently ignored.

    Returns:
        A new ``LipModel`` that shares the frozen backbone.
    """
    for frozen_layer in model.layers:
        frozen_layer.trainable = False

    features = model.layers[-2].output
    hidden = SpectralDense(size_dense, use_bias=True, niter_bjorck=7)(features)
    hidden = GroupSort()(hidden)
    logits = FrobeniusDense(nb_class, disjoint_neurons=True, use_bias=False)(hidden)
    return LipModel(inputs=model.inputs, outputs=logits)


def fine_tune_model_last(model):
    """Freeze every layer of ``model`` except the last one, in place.

    ``BatchCentering`` and ``SpectralConv2D`` layers also get ``freeze = True``.
    Each visited layer is printed together with whether it is a
    ``SpectralConv2D``. Returns ``None``.
    """
    for lyr in model.layers[:-1]:
        lyr.trainable = False
        is_spectral_conv = isinstance(lyr, SpectralConv2D)
        print(lyr, is_spectral_conv)
        if is_spectral_conv or isinstance(lyr, BatchCentering):
            lyr.freeze = True

def redressage_grad(grads, t_var, coeff=0.1, spectral=True):
    """Rescale each non-bias gradient to ``coeff`` times its variable's norm.

    For each (gradient, variable) pair whose variable name does not contain
    ``'bias'``, the gradient is normalized to unit L2 norm and then multiplied
    by ``var_norm * coeff``. Bias gradients are returned untouched.

    For ``spectral_conv2d`` variables (when ``spectral`` is True), ``var_norm``
    is ``sqrt(min(prod(shape[:-1]), shape[-1]))``. That equals the Frobenius
    norm of a kernel with orthonormal rows/columns when it is reshaped to
    ``(prod(shape[:-1]), shape[-1])``. Presumably that is the intent here.

    Args:
        grads: gradients aligned with ``t_var``; entries may be ``None``.
        t_var: variables exposing ``.name`` and ``.shape``.
        coeff: target ratio between gradient norm and variable norm.
        spectral: use the shape-based norm for ``spectral_conv2d`` variables.

    Returns:
        A new list of gradients in the same order. ``None`` gradients stay
        ``None``, and an all-zero gradient stays zero instead of becoming NaN.
    """
    n_grads = []
    for grad, var in zip(grads, t_var):
        if grad is None or 'bias' in var.name:
            n_grads.append(grad)
            continue
        if 'spectral_conv2d' in var.name and spectral:
            var_norm = tf.sqrt(tf.cast(
                tf.math.minimum(tf.reduce_prod(var.shape[:-1]), var.shape[-1]),
                tf.float32))
        else:
            var_norm = tf.sqrt(tf.reduce_sum(var ** 2.0))
        g_norm = tf.sqrt(tf.reduce_sum(grad ** 2.0))
        # divide_no_nan keeps a zero gradient at zero instead of 0/0 = NaN,
        # which would otherwise corrupt the variable on the next update.
        g = tf.math.divide_no_nan(grad, g_norm)
        g = g * var_norm
        g = g * coeff
        n_grads.append(g)
    return n_grads

def rescale_grad_unit(grads, t_var):
    """Rescale gradients to their variables' norms, then rescale globally.

    First, each gradient is rescaled to have the L2 norm of its variable
    (``grad / ||grad|| * ||var||``). Then every rescaled gradient is
    multiplied by ``sum(||var||) / sum(||grad||)``. Both sums run over the
    non-``None`` gradients.

    Args:
        grads: gradients aligned with ``t_var``; entries may be ``None``.
        t_var: the corresponding variables.

    Returns:
        A new list of gradients in the same order. ``None`` entries are kept,
        empty input gives ``[]``, and zero norms give zeros instead of NaN.
    """
    n_grads = []
    old_norm = 0.0
    new_norm = 0.0
    for grad, var in zip(grads, t_var):
        if grad is None:
            n_grads.append(None)
            continue
        var_norm = tf.sqrt(tf.reduce_sum(var ** 2.0))
        g_norm = tf.sqrt(tf.reduce_sum(grad ** 2.0))
        old_norm += g_norm
        new_norm += var_norm
        n_grads.append(tf.math.divide_no_nan(grad, g_norm) * var_norm)
    if all(g is None for g in n_grads):
        # Nothing to rescale: the old code divided 0 by 0 here.
        return n_grads
    # Multiplying by new/old equals dividing by old/new, but stays finite when
    # either sum is zero.
    scale = tf.math.divide_no_nan(new_norm, old_norm)
    return [None if g is None else g * scale for g in n_grads]


def load_compiled_model(expe_name, vanilla=False, verbose=True, folder='./results/', softmax=False, suffix=""):
    """Rebuild, load and compile the model of experiment ``expe_name``.

    Reads ``<folder>/<expe_name>/config.yml`` and applies its global flags via
    ``set_global_variable``. Builds the network from the config's ``network``
    section and loads the weights from
    ``<folder>/<expe_name>/models/<expe_name><suffix>.h5``. ``LipModel``
    instances are condensed after loading.

    Args:
        expe_name: experiment directory name and weight-file stem.
        vanilla: if True and the model is a ``LipModel``, return its
            ``vanilla_export()``.
        verbose: forwarded to ``set_global_variable``; also controls
            ``model.summary()``.
        folder: root results directory (a trailing separator is optional).
        softmax: append a ``Softmax`` layer after the model output.
        suffix: appended to the weight-file stem.

    Returns:
        ``(model, full_config)``, with the model compiled using categorical
        cross-entropy, Adam and categorical accuracy.
    """
    rep = os.path.join(folder, expe_name)
    full_config = load_yaml_config(os.path.join(rep, 'config.yml'))
    # Previously hard-coded to verbose=True, which ignored the argument.
    set_global_variable(full_config, verbose=verbose)
    model = load_model(getParams(full_config, 'network'))
    if verbose:
        model.summary()
    model.load_weights(os.path.join(rep, "models", expe_name + suffix + '.h5'))
    if isinstance(model, LipModel):
        model.condense()
    if vanilla:
        print("vanilla model", isinstance(model, LipModel))
        if isinstance(model, LipModel):
            model = model.vanilla_export()
    if softmax:
        output = Softmax()(model.output)
        if vanilla:
            model = KerasModel(inputs=model.inputs, outputs=output)
        else:
            model = LipModel(inputs=model.inputs, outputs=output)
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['categorical_accuracy'])
    return model, full_config

def add_softmax(model, temperature=1.):
    """Return a compiled copy of ``model`` whose output goes through a softmax.

    When ``temperature`` differs from 1, the outputs are multiplied (not
    divided) by it before the softmax. A ``LipModel`` input is rebuilt as a
    ``LipModel``; anything else becomes a plain Keras ``Model``.
    """
    logits = model.output
    if temperature != 1.:
        logits = Lambda(lambda t: t * temperature)(logits)
    probabilities = Softmax()(logits)

    wrapper_cls = LipModel if isinstance(model, LipModel) else KerasModel
    wrapped = wrapper_cls(inputs=model.inputs, outputs=probabilities)
    wrapped.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['categorical_accuracy'],
    )
    return wrapped



def rgb_grad(grad_x):
    """Map a signed gradient array to a displayable image in ``[0, 1]``.

    The gradient is scaled so its largest-magnitude entry becomes 0.6, shifted
    so its mean is 0.5, and clipped to ``[0, 1]``.

    Returns:
        An array shaped like ``grad_x``. An all-zero gradient gives a uniform
        0.5 image instead of NaN.
    """
    # Normalize by the largest magnitude. The old ``abs(max)`` mis-scaled
    # gradients dominated by negative values and divided by zero when max == 0.
    peak = np.abs(grad_x).max()
    if peak == 0:
        return np.full(np.shape(grad_x), 0.5)
    grad = 0.6 * grad_x / peak
    grad = 0.5 + grad - grad.mean()
    return np.clip(grad, 0, 1)

def grad_to_color(grad_x):
    """Colour-code the sign of a channel-summed gradient map.

    The gradient is summed over its last axis. Negative totals go to channel 0
    and positive totals to channel 1; any other channels stay 0. The image is
    then scaled so its maximum is 1.

    Args:
        grad_x: array-like of shape ``(H, W, C)`` with ``C >= 2``.

    Returns:
        A float32 array shaped like ``grad_x``. It is all zeros when the
        channel sums are zero everywhere (previously NaN).
    """
    arr = np.asarray(grad_x)
    summed = np.sum(arr, axis=-1)
    colour = np.zeros(arr.shape, dtype=np.float32)
    colour[:, :, 0] = np.maximum(-summed, 0)
    colour[:, :, 1] = np.maximum(summed, 0)
    # The former L2 normalization of ``summed`` cancelled out in this max
    # normalization, and it turned an all-zero map into NaN.
    peak = np.max(colour)
    if peak > 0:
        colour = colour / peak
    return colour
def grad_to_color2(grad_x):
    """Colour-code the positive and negative gradient mass at each pixel.

    Channel 0 holds the sum of the negative parts over the last axis and
    channel 1 the sum of the positive parts; any other channels stay 0. The
    image is then scaled so its maximum is 1.

    Args:
        grad_x: array-like of shape ``(H, W, C)`` with ``C >= 2``.

    Returns:
        A float32 array shaped like ``grad_x``. It is all zeros for an all-zero
        gradient (previously NaN).
    """
    arr = np.asarray(grad_x)
    pos_mass = np.sum(np.maximum(arr, 0), axis=-1)
    neg_mass = np.sum(np.maximum(-arr, 0), axis=-1)
    colour = np.zeros(arr.shape, dtype=np.float32)
    colour[:, :, 0] = neg_mass
    colour[:, :, 1] = pos_mass
    peak = np.max(colour)
    if peak > 0:
        colour = colour / peak
    return colour

def explain_for_binary_dummies(model, x, y, x_rescale=None, counter_coeff=0.2, div=255, filename=None):
    """Plot a gradient explanation for a (presumably single-output) model.

    Three panels are drawn:
    1. the input;
    2. the colour-coded input gradient (``grad_to_color2``);
    3. a counterfactual: the input moved against the sign of the prediction
       along the unit-norm gradient, by ``counter_coeff * ||x||``, then
       clipped to ``[0, div]``.

    Args:
        model: callable Keras model. Presumably it returns one score per
            sample; TODO confirm for multi-output models.
        x: a single input (no batch axis).
        y: unused; kept for call-site compatibility.
        x_rescale: optional array of ``x``'s shape, shown and perturbed
            instead of ``x``.
        counter_coeff: relative size of the counterfactual step.
        div: pixel range; panels 1 and 3 are divided by it for display.
        filename: if given, the figure is saved to ``filename + ".jpg"`` and
            closed; otherwise it is left open.
    """
    image = tf.Variable(np.expand_dims(x, axis=0))
    with tf.GradientTape() as tape:
        tape.watch(image)
        prediction = model(image, training=False)
    grads_class = tape.gradient(prediction, image).numpy().reshape(x.shape)

    if x_rescale is not None:
        x = x_rescale
    fig, axs = plt.subplots(1, 3, figsize=(10, 12))
    axs[0].imshow(x / div)

    grad_norm = np.linalg.norm(grads_class)
    if grad_norm > 0:
        direction = grads_class / grad_norm
    else:
        # Flat gradient: there is no counterfactual direction; avoid 0/0 = NaN.
        direction = np.zeros_like(grads_class)
    img_exp = np.clip(
        x - np.sign(prediction.numpy()) * counter_coeff * direction * np.linalg.norm(x),
        0, div)
    axs[1].imshow(grad_to_color2(grads_class))
    # Scale by ``div`` like panel 1 and the clip above (was hard-coded to 255).
    axs[2].imshow(img_exp / div)
    for ax in axs:
        ax.set_axis_off()
    if filename is not None:
        plt.savefig(filename + ".jpg", bbox_inches='tight', pad_inches=0)
        plt.close(fig)
    

def img_int_to_float(img):
    """Convert an image in the 0..255 range to floats in ``[0, 1]``.

    Values are divided by 255 and then clipped, so anything outside
    ``[0, 255]`` saturates at 0 or 1.
    """
    return np.clip(img / 255., 0., 1.)

def set_global_variable(full_config, verbose=False):
    """Push the global deel switches stored in ``full_config``.

    Each value is read with ``full_config.get(key, default)`` and passed to
    its setter, in the order listed below. When ``verbose`` is true, some of
    the resulting module-level flags are printed.
    """
    extra = deel.lip.extra_layers
    settings = (
        (set_add_coeff, 'add_coeff', 0.5),
        (set_grad_passthrough_bjorck, 'passthrough', False),
        (set_stop_grad_spectral, 'stop_grad_spectral', False),
        (extra.set_init, 'batch_init', False),
        (extra.set_avg, 'batch_avg', False),
        (extra.set_center, 'batch_center', False),
        (extra.set_alpha, 'batch_alpha', 0.99),
        (extra.set_stop_gradient, 'batch_grad', False),
    )
    for setter, key, default in settings:
        setter(full_config.get(key, default))

    if not verbose:
        return
    normalizers = deel.lip.normalizers
    report = (
        ("STOP_GRAD_SPECTRAL ", normalizers.STOP_GRAD_SPECTRAL),
        ("GRAD_PASSTHROUGH_BJORCK ", normalizers.GRAD_PASSTHROUGH_BJORCK),
        ("DEFAULT_ALPHA ", extra.DEFAULT_ALPHA),
        ("DEFAULT_INIT ", extra.DEFAULT_INIT),
        ("DEFAULT_AVG ", extra.DEFAULT_AVG),
        ("DEFAULT_CENTER ", extra.DEFAULT_CENTER),
    )
    for label, value in report:
        print(label, value)

def explain_for_dummies(model, x, y, x_rescale=None, counter_coeff=0.2, div=255, filename=None, class_counter=-1, post_process=img_int_to_float, names=None):
    """Plot a counterfactual gradient explanation for a multi-class model.

    The predicted class (``class_found``) is the argmax of the model output.
    Unless ``class_counter`` is given, the counter class is the runner-up.
    The gradient of the counter-class score with respect to the input is
    computed, and three panels are drawn:
    1. the input;
    2. ``grad_to_color2`` of that gradient;
    3. ``x + counter_coeff * ||x|| * unit_gradient``.
    ``post_process`` is applied to panels 1 and 3 before display.

    Args:
        model: callable Keras model returning per-class scores.
        x: a single input (no batch axis).
        y: unused; kept for call-site compatibility.
        x_rescale: optional array of ``x``'s shape, used for display and for
            the counterfactual instead of ``x``.
        counter_coeff: relative size of the counterfactual step.
        div: unused; kept for call-site compatibility.
        filename: if given, the figure is saved as
            ``<filename>_<found>_to_<counter>.jpg`` and closed. Otherwise
            ``"<found> to <counter>"`` is printed and the figure is left open.
        class_counter: counter class index; -1 selects the runner-up.
        post_process: display transform, or ``None`` for none.
        names: optional sequence of class labels used instead of indices.

    Returns:
        ``(class_found, class_counter)``.
    """
    image = tf.Variable(np.expand_dims(x, axis=0))
    # Copy so that masking the top class never writes into the tensor's buffer.
    scores = np.array(model(image, training=False).numpy().squeeze())
    class_found = np.argmax(scores)
    if class_counter == -1:
        # Runner-up: mask the predicted class before taking the argmax again.
        scores[class_found] = -np.inf
        class_counter = np.argmax(scores)

    # Only the counter-class gradient is displayed, so the unused gradient
    # pass for ``class_found`` has been dropped.
    with tf.GradientTape() as tape:
        tape.watch(image)
        counter_score = model(image, training=False)[0, class_counter]
    grads_counter = tape.gradient(counter_score, image).numpy().reshape(x.shape)

    if x_rescale is not None:
        x = x_rescale
    fig, axs = plt.subplots(1, 3, figsize=(10, 12))

    grad_norm = np.linalg.norm(grads_counter)
    if grad_norm > 0:
        direction = grads_counter / grad_norm
    else:
        # Flat gradient: avoid 0/0 = NaN in the counterfactual.
        direction = np.zeros_like(grads_counter)
    img_exp = x + counter_coeff * direction * np.linalg.norm(x)
    if post_process is not None:
        x = post_process(x)
        img_exp = post_process(img_exp)
    axs[0].imshow(x)
    axs[1].imshow(grad_to_color2(grads_counter))
    axs[2].imshow(img_exp)
    for ax in axs:
        ax.set_axis_off()

    if names is not None:
        label_found, label_counter = names[class_found], names[class_counter]
    else:
        label_found, label_counter = str(class_found), str(class_counter)
    if filename is not None:
        plt.savefig(filename + "_" + label_found + "_to_" + label_counter + ".jpg", bbox_inches='tight')
        plt.close(fig)
    else:
        print(label_found + " to " + label_counter)
    return class_found, class_counter

def mesure_metric(model, X, Y, expe_name, metric, explainers, nb=1000):
    """Score each explainer's explanations with ``metric`` and write a report.

    For each explainer, explanations are computed on the first ``nb``
    samples of ``X`` / ``Y``. If they have more than 3 dimensions, they are
    averaged over the last axis before being passed to ``metric``. The file
    named ``expe_name`` is overwritten with one ``"<ExplainerClass> : <score>"``
    line per explainer.

    Args:
        model: unused; kept for call-site compatibility.
        X, Y: inputs and targets, sliceable along axis 0.
        expe_name: path of the report file.
        metric: callable mapping an explanations array to a float.
        explainers: callables ``explainer(X, Y)`` returning explanations.
        nb: number of samples used.
    """
    # ``with`` closes the report even if an explainer or the metric raises.
    with open(expe_name, "w") as f:
        for explainer in explainers:
            explainer_name = explainer.__class__.__name__
            print("compute explanation : ", explainer_name)
            explanations = explainer(X[:nb], Y[:nb])
            if len(explanations.shape) > 3:
                explanations = np.mean(explanations, -1)

            print("compute metric : ", explainer_name)
            start = time.time()
            score = metric(explanations)
            end = time.time()
            print("computation time :", (end - start), f" {score:.4f}")
            f.write(f"{explainer_name} : {score:.4f}\n")

def mesure_stability(model, X, Y, expe_name, metric, explainers, nb=1000):
    """Score each explainer with ``metric`` and write a report.

    ``metric`` receives the explainer object itself. The file named
    ``expe_name`` is overwritten with one ``"<ExplainerClass> : <score>"``
    line per explainer.

    Args:
        model, X, Y, nb: unused; kept for call-site compatibility.
        expe_name: path of the report file.
        metric: callable mapping an explainer to a float.
        explainers: explainer objects to score.
    """
    # ``with`` closes the report even if the metric raises.
    with open(expe_name, "w") as f:
        for explainer in explainers:
            explainer_name = explainer.__class__.__name__
            print("compute metric : ", explainer_name)
            start = time.time()
            score = metric(explainer)
            end = time.time()
            print("computation time :", (end - start), f" {score:.4f}")
            f.write(f"{explainer_name} : {score:.4f}\n")

def get_data(data, nb_images=1000, batch_size=256):
    """Materialize roughly ``nb_images`` samples from a batched dataset.

    Pulls ``nb_images // batch_size + 1`` batches from ``data`` and
    concatenates them along axis 0. ``batch_size`` should match the batch
    size ``data`` actually yields. The result is not truncated, so it may
    hold more than ``nb_images`` samples. If the dataset runs out early, the
    batches collected so far are returned. The cumulative shape of ``X`` is
    printed after each batch.

    Args:
        data: iterable yielding ``(X_batch, Y_batch)`` pairs (tensors or arrays).
        nb_images: number of samples wanted.
        batch_size: batch size of ``data``.

    Returns:
        ``(X, Y)`` as numpy arrays.

    Raises:
        ValueError: if ``data`` yields no batch at all.
    """
    x_parts, y_parts = [], []
    n_rows = 0
    # zip stops on the exhausted range first, so no extra batch is pulled and
    # a short dataset ends the loop instead of raising StopIteration.
    for _, (x_batch, y_batch) in zip(range(nb_images // batch_size + 1), data):
        x_batch = np.asarray(x_batch)
        y_batch = np.asarray(y_batch)
        x_parts.append(x_batch)
        y_parts.append(y_batch)
        n_rows += x_batch.shape[0]
        print((n_rows,) + x_batch.shape[1:])
    if not x_parts:
        raise ValueError("get_data: the dataset did not yield any batch")
    # A single concatenation avoids the quadratic re-copy of the old loop.
    return np.concatenate(x_parts, axis=0), np.concatenate(y_parts, axis=0)

def evaluate_dist(model, data, expe_name, nb_images=10):
    """Compare Saliency and SmoothGrad explanations on a few images.

    Loads about ``2 * nb_images`` samples with ``get_data`` and computes
    Saliency maps and SmoothGrad maps (50 samples,
    ``noise = 0.1 * (max(X) - min(X))``). It then reports the per-image mean
    absolute difference between the two, averaged over the first
    ``nb_images`` images and normalized by the mean absolute Saliency value.

    Args:
        model: model to explain.
        data: batched dataset passed to ``get_data``.
        expe_name: unused; kept for call-site compatibility.
        nb_images: number of images compared.

    Returns:
        The normalized distance, which is also printed as ``"Mean l2"``.
    """
    X, Y = get_data(data, nb_images=nb_images * 2)
    x_max, x_min = X.max(), X.min()
    sal = Saliency(model)
    phi = sal(X, Y).numpy()

    smooth = SmoothGrad(model, nb_samples=50, batch_size=256, noise=0.1 * (x_max - x_min))
    # Bug fix: phi2 used to be computed with ``sal`` again, so the distance
    # was always 0.
    phi2 = smooth(X, Y).numpy()
    mean_sal = (np.abs(phi)).mean()
    print("mean :", mean_sal)
    # The loop previously appended the same global value each iteration;
    # now it compares image i with image i.
    l2_dist = [np.mean(np.abs(phi[i] - phi2[i])) for i in range(min(nb_images, len(phi)))]
    score = np.mean(l2_dist) / mean_sal
    print("Mean l2", score)
    return score



def evaluate_kolmo(model, data, expe_name, nb_images=10):
    """Measure the JPEG-encoded size of Saliency maps.

    This is presumably a compressibility / Kolmogorov-complexity proxy. For
    each of the first ``nb_images`` samples, the Saliency map is min-max
    scaled to 0..255 ``uint8`` and JPEG-encoded with OpenCV, and the encoded
    byte count is recorded. Mean and std are printed in bytes and kilobytes.

    Args:
        model: model to explain.
        data: batched dataset passed to ``get_data``.
        expe_name: unused; kept for call-site compatibility.
        nb_images: number of maps measured.

    Returns:
        ``(mean_bytes, std_bytes)``.
    """
    X, Y = get_data(data, nb_images=nb_images * 2)
    explainer = Saliency(model)
    sizes = []
    for i in range(nb_images):
        phi = explainer([X[i]], [Y[i]]).numpy()
        phi = phi.reshape((phi.shape[1], phi.shape[2]))
        phi = phi - phi.min()
        peak = phi.max()
        if peak > 0:
            # A constant map used to produce 0/0 = NaN before the uint8 cast.
            phi = phi / peak
        phi = (phi * 255.0).astype(np.uint8)
        # Encode in memory rather than writing 'img.jpg' into the working
        # directory. The byte count should match what cv2.imwrite would write
        # with the same default parameters.
        ok, encoded = cv2.imencode('.jpg', phi)
        if not ok:
            raise RuntimeError("evaluate_kolmo: JPEG encoding failed")
        sizes.append(encoded.size)
    print('mean bytes size', np.mean(sizes), "std", np.std(sizes))
    print('mean kilo-bytes', np.mean(sizes) / 1_000, "std", np.std(sizes) / 1_000)
    return np.mean(sizes), np.std(sizes)