import os
import shutil
import sys
import time
sys.path.append('./')
sys.path.append('../')
os.environ['TF_CPP_MIN_LOG_LEVEL']='3'
#os.environ['AUTOGRAPH_VERBOSITY'] = '1'
import tensorflow as tf
import tensorflow_datasets as tfds
tf.get_logger().setLevel('ERROR')
from tensorflow.keras.applications.resnet50 import ResNet50
import numpy as np
import pickle
import cv2
from deel.datasets.imagenet_dataset import imagenet_dataset
from deel.utils.yaml_to_params import load_yaml_config,getParams, getFunctionFromModules, dumdict2yaml
from deel.utils.yaml_loader import load_model, loadFunctionList
from deel.utils.yaml_loader import load_optimizer_and_loss
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
                                  SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,SobolAttributionMethod,
                                  GradCAMPP, Lime, KernelShap)

from matplotlib import pyplot as plt
from xplique.plots import plot_attributions,plot_attribution
from deel.utils.lip_utils import explain_for_dummies, load_compiled_model,add_softmax,get_data, mesure_metric,  spearman_dist, mesure_stability,evaluate_kolmo,evaluate_dist
from deel.datasets.imagenet_dataset import normalize_vgg,normalize
from harmonization.common import load_clickme_val
from harmonization.evaluation import evaluate_clickme
from xplique.attributions.base import WhiteBoxExplainer, sanitize_input_output
from xplique.commons import batch_gradient
from xplique.types import Optional, Union
from tensorflow.keras.applications.resnet50 import preprocess_input
from xplique.metrics import MuFidelity,AverageStability



class SaliencyInv(Saliency):
    """``Saliency`` explainer that preprocesses its inputs itself.

    Inputs are copied into a NumPy array and passed through keras' ResNet50
    ``preprocess_input`` before the regular ``Saliency.explain`` runs, so
    callers can hand over images that have not been preprocessed yet.
    """

    def explain(self, inputs, targets):
        """Preprocess ``inputs`` and delegate to ``Saliency.explain``."""
        raw = np.array(inputs)
        ready = preprocess_input(raw)
        return super().explain(ready, targets)

def preprocess_images(x):
    """Return ``x`` normalized with ``normalize_vgg``."""
    normalized = normalize_vgg(x)
    return normalized

def inverse_process(x):
    """Undo keras ResNet50 ``preprocess_input`` ('caffe' mode) for display/saving.

    ``preprocess_input`` (used to build the datasets in this script) flips
    RGB -> BGR and subtracts the per-channel ImageNet means
    ``[103.939, 116.779, 123.68]`` from the B, G, R channels. This adds those
    means back to the first three channels and clips to ``[0, 255]``. The
    channel order stays BGR, which is what ``cv2.imwrite`` expects.

    Args:
        x: preprocessed image array whose last axis holds the channels
            (at least 3). It is not modified.

    Returns:
        A new float array with the same shape as ``x``, values in ``[0, 255]``.
    """
    # Means in the (B, G, R) order of the preprocessed channels.
    bgr_mean = np.array([103.939, 116.779, 123.68])
    x = np.asarray(x)
    # Integer input cannot hold the float offsets; promote it to float32.
    dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float32
    # Work on a copy so the caller's array is left untouched.
    x = x.astype(dtype, copy=True)
    x[..., :3] += bgr_mean.astype(dtype)
    return np.clip(x, 0, 255)
def data_range(data):
    """Print the shape and min/mean/max of the first input batch of ``data``.

    ``data`` must be iterable and yield ``(x, y)`` pairs whose elements expose
    ``.numpy()``. The "ok1"/"ok2" prints bracket the first fetch from the
    iterator (presumably debugging markers for a slow/hanging pipeline).
    """
    batches = iter(data)
    print("ok1")
    first_x, first_y = next(batches)
    print("ok2")
    first_x, first_y = first_x.numpy(), first_y.numpy()
    print(first_x.shape)
    print(first_x.min(), first_x.mean(), first_x.max())
def data_stats(data, prefix="val", nb_images=5, out_dir="images"):
    """Print statistics of the first batch of ``data`` and dump a few images.

    Prints the batch shape, the per-channel mean and std (reduced over every
    axis but the last), and the global min/max. It then de-normalizes the
    batch with ``inverse_process`` and writes up to ``nb_images`` images as
    ``<out_dir>/<prefix>_<i>.jpeg`` using OpenCV.

    Args:
        data: iterable yielding ``(x, y)`` batches whose elements expose ``.numpy()``.
        prefix: file-name prefix for the saved images.
        nb_images: maximum number of images to save (fewer if the batch is smaller).
        out_dir: output directory; created if missing.
    """
    dataset_it = iter(data)
    print("ok1")
    x, _ = next(dataset_it)
    print("ok2")
    x = x.numpy()
    print(x.shape)
    channel_axes = tuple(range(x.ndim - 1))  # every axis except the last one
    print(x.mean(axis=channel_axes), x.min(), x.max(), x.std(axis=channel_axes))
    x = inverse_process(x)
    # cv2.imwrite does not create directories; it just returns False.
    os.makedirs(out_dir, exist_ok=True)
    # Do not index past the end of a batch smaller than nb_images.
    for i in range(min(nb_images, x.shape[0])):
        path = os.path.join(out_dir, prefix + "_" + str(i) + ".jpeg")
        if not cv2.imwrite(path, x[i]):
            print("could not write", path)



def evaluate_results(models, data, nb_batches=500):
    """Print per-batch top-1 accuracy and mean top score of a classifier.

    For each batch, prints the fraction of samples whose argmax prediction
    matches ``argmax(y)`` and the mean over the batch of the highest output
    score.

    Args:
        models: the model to evaluate, called as ``models(x, training=False)``
            (the historical plural name is kept for backward compatibility).
        data: iterable of ``(x, y)`` batches whose elements expose ``.numpy()``;
            ``y`` is compared through ``argmax`` (presumably one-hot labels).
        nb_batches: maximum number of batches to evaluate; evaluation stops
            early when ``data`` is exhausted.
    """
    for batch_idx, (x, y) in enumerate(data):
        if batch_idx >= nb_batches:
            break
        # Inference mode: with training=True, BatchNormalization layers use
        # batch statistics (and can update their moving averages) instead of
        # the learned ones, which skews the scores and alters the model.
        scores = models(x, training=False).numpy()
        y = y.numpy()
        correct = np.argmax(y, axis=1) == np.argmax(scores, axis=1)
        print(correct.sum() / scores.shape[0], scores.max(axis=1).mean())
def save_gradcam(model, data, expe_name, nb_images=10):
    """Plot SmoothGrad attributions for the first batch of ``data``.

    Despite its name, this uses ``SmoothGrad`` (100 samples, noise equal to
    10% of the batch's value range), not GradCAM. The batch is cut into
    consecutive groups of ``nb_images`` images, and a trailing partial group is
    skipped. Channel-wise attributions are averaged over axis 3, and each group
    is saved to ``./results/<expe_name>/curves/smoothgrad/gradcam_<k>.jpg``.
    """
    batches = data.__iter__()
    out_dir = "./results/" + expe_name + "/curves/smoothgrad/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    images, labels = next(batches)
    images_np = images.numpy()
    noise_std = 0.1 * (images_np.max() - images_np.min())
    n_groups = int(images.shape[0] // nb_images)
    explainer = SmoothGrad(model, nb_samples=100, batch_size=100, noise=noise_std)
    for group in range(n_groups):
        print("Saliency n°", group)
        start = group * nb_images
        x_group = images[start:start + nb_images]
        y_group = labels[start:start + nb_images]
        explanations = explainer.explain(x_group.numpy(), y_group.numpy())
        if len(explanations.shape) > 3:
            # Collapse the trailing (per-channel) axis into a single map.
            explanations = np.mean(explanations, axis=3)
        background = inverse_process(x_group.numpy())
        plot_attributions(explanations, background, cmap='jet', alpha=0.4,
                          cols=nb_images, absolute_value=True, clip_percentile=0.)
        plt.savefig(out_dir + "gradcam_" + str(group) + ".jpg", bbox_inches='tight')
        plt.close()
def save_saliency(model, data, expe_name, nb_images=10):
    """Plot vanilla ``Saliency`` maps for the first batch of ``data``.

    The batch is cut into consecutive groups of ``nb_images`` images, and a
    trailing partial group is skipped. Each group is drawn over its
    ``inverse_process``-ed images and saved to
    ``./results/<expe_name>/curves/saliency/saliency_<k>.jpg``. When
    ``nb_images == 1``, the single-image ``plot_attribution`` is used instead
    of ``plot_attributions``.
    """
    batches = data.__iter__()
    out_dir = "./results/" + expe_name + "/curves/saliency/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    images, labels = next(batches)
    n_groups = int(images.shape[0] // nb_images)
    explainer = Saliency(model)
    overlay_opts = dict(cmap='jet', alpha=0.5, absolute_value=True, clip_percentile=0)
    for group in range(n_groups):
        print("Saliency n°", group)
        start = group * nb_images
        x_group = images[start:start + nb_images]
        y_group = labels[start:start + nb_images]
        explanations = explainer.explain(x_group.numpy(), y_group.numpy())
        background = inverse_process(x_group.numpy())
        if nb_images == 1:
            plot_attribution(explanations[0, :, :], background[0, :, :, :], **overlay_opts)
        else:
            plot_attributions(explanations, background, cols=nb_images, **overlay_opts)
        plt.savefig(out_dir + "saliency_" + str(group) + ".jpg", bbox_inches='tight', pad_inches=0)
        plt.close()
def save_saliency_combine(model, data, expe_name, nb_images=10):
    """Plot ``SaliencyCombine`` attributions for the first batch of ``data``.

    The batch is cut into consecutive groups of ``nb_images`` images, and a
    trailing partial group is skipped. Each group is saved to
    ``./results/<expe_name>/curves/grads_comb/saliency_comb_<k>.jpg``.
    """
    # NOTE(review): SaliencyCombine is neither defined nor imported in this file
    # as shown, so calling this raises NameError unless it is provided
    # elsewhere. Confirm where it is supposed to come from.
    batches = data.__iter__()
    out_dir = "./results/" + expe_name + "/curves/grads_comb/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    images, labels = next(batches)
    n_groups = int(images.shape[0] // nb_images)
    explainer = SaliencyCombine(model)
    for group in range(n_groups):
        print("Saliency n°", group)
        start = group * nb_images
        x_group = images[start:start + nb_images]
        y_group = labels[start:start + nb_images]
        explanations = explainer.explain(x_group.numpy(), y_group.numpy())
        background = inverse_process(x_group.numpy())
        plot_attributions(explanations, background, cmap='jet', alpha=0.4,
                          cols=nb_images, absolute_value=True)
        plt.savefig(out_dir + "saliency_comb_" + str(group) + ".jpg", bbox_inches='tight')
        plt.close()

def save_saliency_l2(model, data, expe_name, nb_images=10):
    """Plot ``SaliencyL2`` attributions for the first batch of ``data``.

    The batch is cut into consecutive groups of ``nb_images`` images, and a
    trailing partial group is skipped. Each group is saved to
    ``./results/<expe_name>/curves/grads_l2/saliency_<k>.jpg``.
    """
    # NOTE(review): SaliencyL2 is neither defined nor imported in this file as
    # shown, so calling this raises NameError unless it is provided elsewhere.
    # Confirm where it is supposed to come from.
    batches = data.__iter__()
    out_dir = "./results/" + expe_name + "/curves/grads_l2/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    images, labels = next(batches)
    n_groups = int(images.shape[0] // nb_images)
    explainer = SaliencyL2(model)
    for group in range(n_groups):
        print("Saliency n°", group)
        start = group * nb_images
        x_group = images[start:start + nb_images]
        y_group = labels[start:start + nb_images]
        explanations = explainer.explain(x_group.numpy(), y_group.numpy())
        background = inverse_process(x_group.numpy())
        plot_attributions(explanations, background, cmap='jet', alpha=0.4,
                          cols=nb_images, absolute_value=True)
        plt.savefig(out_dir + "saliency_" + str(group) + ".jpg", bbox_inches='tight')
        plt.close()
def save_gradient(model, data, expe_name, nb_images=10):
    """Run ``explain_for_dummies`` on up to ``nb_images`` images of the first batch.

    Each image gets its own output prefix
    ``./results/<expe_name>/curves/grads/grad_<i>`` and its
    ``inverse_process``-ed version as ``x_rescale``.
    """
    out_dir = "./results/" + expe_name + "/curves/grads/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    images, labels = next(data.__iter__())
    backgrounds = inverse_process(images.numpy())
    n_images = min(nb_images, images.shape[0])
    for idx in range(n_images):
        print("computing gradient n°", idx)
        # The returned pair is not used; unpacking is kept as in the original call.
        _cl, _count = explain_for_dummies(model, images[idx], labels[idx],
                                          x_rescale=backgrounds[idx],
                                          counter_coeff=0.15,
                                          filename=out_dir + "grad_" + str(idx))

def evaluate_mu_unif(model, data, expe_name):
    """Measure MuFidelity of several explainers with a uniform-noise baseline.

    Draws 2000 samples with ``get_data`` and builds a ``MuFidelity`` metric on
    the first ``nb`` (1000) of them. Its baseline is uniform noise over the
    observed data range ``[x_min, x_max)``. The metric, explainers and output
    path ``./results/<expe_name>/curves/mufidelity.txt`` are handed to
    ``mesure_metric``.

    Relies on the module-level ``batch_size``.
    """
    X, Y = get_data(data, nb_images=2000)
    x_max, x_min = X.max(), X.min()
    print(X.shape)
    nb = 1000
    #model.layers[-1].activation = tf.keras.activations.linear

    def baseline_uniform(x):
        # Uniform noise spanning the observed data range [x_min, x_max).
        # (Offsetting by -x_min, as before, pushed the noise outside that range.)
        return np.random.uniform(low=x_min, high=x_max, size=x.shape)

    metric = MuFidelity(model, X[:nb], Y[:nb], batch_size, grid_size=9,
                        nb_samples=50, baseline_mode=baseline_uniform)
    explainers = [
        Saliency(model),
        SmoothGrad(model, nb_samples=50, batch_size=batch_size, noise=0.1 * (x_max - x_min)),
        #Rise(model, nb_samples=400, batch_size=batch_size),
        #IntegratedGradients(model, steps=50, batch_size=batch_size),
        GradientInput(model),
        GradCAM(model),
    ]
    mesure_metric(model, X, Y, "./results/" + expe_name + "/curves/mufidelity.txt",
                  metric, explainers, nb=nb)

def evaluate_mu_zero(model, data, expe_name):
    """Measure MuFidelity of several explainers with MuFidelity's default baseline.

    Draws 2000 samples with ``get_data`` and builds a ``MuFidelity`` metric on
    the first ``nb`` (1000) of them. The metric, explainers and output path
    ``./results/<expe_name>/curves/mufidelity_zero.txt`` are handed to
    ``mesure_metric``.

    Relies on the module-level ``batch_size``.
    """
    X, Y = get_data(data, nb_images=2000)
    x_max, x_min = X.max(), X.min()  # only used to scale the SmoothGrad noise
    print(X.shape)
    nb = 1000
    #model.layers[-1].activation = tf.keras.activations.linear
    # No baseline_mode is passed, so MuFidelity falls back to its default
    # baseline (presumably zero, as the function and output-file names suggest).
    # The unused uniform-baseline lambda that used to be defined here is removed.
    metric = MuFidelity(model, X[:nb], Y[:nb], batch_size, grid_size=9, nb_samples=50)
    explainers = [
        Saliency(model),
        SmoothGrad(model, nb_samples=50, batch_size=batch_size, noise=0.1 * (x_max - x_min)),
        #Rise(model, nb_samples=400, batch_size=batch_size),
        IntegratedGradients(model, steps=50, batch_size=batch_size),
        GradientInput(model),
        GradCAM(model),
    ]
    mesure_metric(model, X, Y, "./results/" + expe_name + "/curves/mufidelity_zero.txt",
                  metric, explainers, nb=nb)

def evaluate_stability(model, data, expe_name):
    """Measure AverageStability (Spearman distance) of three gradient explainers.

    Draws 1000 samples with ``get_data``, prints their value range, and builds
    an ``AverageStability`` metric with radius ``0.15 * x_max`` and 10 samples.
    The metric, explainers and output path
    ``./results/<expe_name>/curves/stability_spear.txt`` are handed to
    ``mesure_stability``.

    Relies on the module-level ``batch_size``.
    """
    # NOTE(review): the radius scales with x_max only, not with x_max - x_min
    # as the SmoothGrad noise does elsewhere. Confirm this is intended.
    X, Y = get_data(data, nb_images=1000)
    x_max, x_min = X.max(), X.min()
    n_eval = 1000
    print(f"data range : [{x_min},{x_max}]")
    #model.layers[-1].activation = tf.keras.activations.linear
    metric = AverageStability(model, X[:n_eval], Y[:n_eval], batch_size,
                              radius=.15 * x_max, distance=spearman_dist, nb_samples=10)
    explainers = [
        Saliency(model),
        SmoothGrad(model, nb_samples=50, batch_size=batch_size),
        IntegratedGradients(model, steps=50, batch_size=batch_size),
    ]
    out_path = "./results/" + expe_name + "/curves/stability_spear.txt"
    mesure_stability(model, X, Y, out_path, metric, explainers, nb=n_eval)

# --- Script entry: data, model and explainer set-up (runs at import time) ---
# ClickMe validation set. NOTE(review): not referenced anywhere else in this
# file as shown (the commented-out evaluate_clickme call does not use it).
# It is a module-level name, so confirm before removing.
clickme_dataset = load_clickme_val(batch_size = 128)
# Module-level batch size. evaluate_mu_unif / evaluate_mu_zero /
# evaluate_stability read this global directly.
batch_size = 256

# ImageNet splits preprocessed with keras' ResNet50 preprocess_input and loaded
# with shuffle_files=False / shuffle=0 (presumably a deterministic order; verify
# in imagenet_dataset).
train, val,train_val, info = imagenet_dataset(batch_size = batch_size,
                                              preprocess = preprocess_input,
                                              compute_train_val=True,
                                              shuffle_files = False,
                                              shuffle = 0)
# Pretrained ImageNet ResNet50. It is compiled so that model.evaluate (see the
# commented-out call below) can be used.
model = ResNet50(weights='imagenet')
model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['categorical_accuracy'])
model.summary()
#results = model.evaluate(val, steps = 100000//batch_size)
# Experiment name, used to build the ./results/<expe_name>/... output paths.
expe_name = "ResNet50"
# Saliency explainer that applies preprocess_input itself; referenced by the
# commented-out evaluate_clickme call below.
explainer = SaliencyInv(model)
# Model output directory; only referenced by commented-out code below.
rep = "./results/"+expe_name+"/models/"
#evaluate_kolmo(model,val,expe_name,nb_images=10)
#evaluate_dist(model,val,expe_name,nb_images=100)
#evaluate_mu_unif(model,val,expe_name)
#evaluate_mu_zero(model,val,expe_name)
# X,Y = get_data(val,nb_images = 2000)
# print(X.shape)
# nb = 2000
# #model.layers[-1].activation = tf.keras.activations.linear
# x_max, x_min = X.max(), X.min()
# baseline_zero    = 0.0
# baseline_uniform = lambda x: np.random.uniform(size=x.shape) * (x_max-x_min) - x_min
# metric = MuFidelity(model, X[:nb], Y[:nb],  batch_size, grid_size = 9,nb_samples=50,baseline_mode = baseline_uniform )
# explainers = [Saliency(model),
#               SmoothGrad(model, nb_samples=50, batch_size=batch_size),
#               Rise(model, nb_samples=4000, batch_size=batch_size),
#               IntegratedGradients(model, steps=50, batch_size=batch_size),
#               GradientInput(model),
#               GradCAM(model)]
# mesure_metric(model,X,Y, "./results/"+expe_name+"/curves/mufidelityuniform.txt",metric,explainers,nb = nb)
#model.save(rep+expe_name+"vanilla.h5")



#
#vanilla = model.vanilla_export()
#model = tf.keras.models.load_model("results/resnet_50_save/models/resnet_50_save_vanilla/")
#model.compile(loss='categorical_crossentropy',
#                optimizer='adam',
#                metrics=['categorical_accuracy','top_k_categorical_accuracy'])
#results = model.evaluate(val, steps = 100000//batch_size)



#model_soft = add_softmax(model)
#save_gradcam(model,val,expe_name,nb_images=1)
#save_gradient(model,val,expe_name,nb_images=100)
#model_soft.summary()
#save_saliency(model,val,expe_name,nb_images=1)
#save_saliency_combine(model,val,expe_name,nb_images=10)
# scores = evaluate_clickme(model, 
#                           explainer = explainer,
#                           preprocess_inputs=None)
# print('alignement :',scores['alignment_score'])
# f = open("./results/"+expe_name+"/curves/aligement.txt", "w")
# f.write(f"alignement :{scores['alignment_score']:0.3f}")
# f.write('\n')
# f.flush()
# f.close()

#save_gradcam(model_soft,val,expe_name,nb_images=10)

#save_gradient(model,val,expe_name,nb_images=256)
#print("train")
#evaluate_results(model, train)
#print("val")
#evaluate_results(model, val)
#print("trainval")
#evaluate_results(model, val)
#print(results)
#nb = 1000,nb_samples=[128], grid_size=[9])