
import os
#os.environ["CUDA_VISIBLE_DEVICES"]=""

import argparse
import sys
print(sys.path)
import pandas as pd
import time
import tqdm

try:
    import wandb
except ImportError:
    wandb = None

import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow.keras.losses as losses
import tensorflow.keras.optimizers as optimizers
import dlt.data.loader as loader
import dlt.data.pipeline as pipeline
import dlt.data.augmentation as aug
import dlt.infrastructure.distributed_training as distributed
from dlt.extras.layers import skip_connections as skips
from dlt.model_factory import *
from dlt.model_factory.utils import ClassParam

from deel.lip.model import vanillaModel
from deel import lip as deellip

from sklearn.metrics import multilabel_confusion_matrix

import numpy as np

from deel.custom_losses import HKR as HKR_multilabel

import utils_train

def loss_multilabel_cross_entropy_with_logits():
    """Build a multilabel sigmoid cross-entropy loss that consumes raw logits.

    Returns:
        A ``loss(y_true, y_pred)`` callable. It computes the element-wise
        sigmoid cross-entropy between ``y_pred`` (logits) and ``y_true``
        (cast to float32), sums it over axis 1 and averages the result
        over the remaining (batch) axis.
    """
    def loss(y_true, y_pred):
        targets = tf.cast(y_true, tf.float32)
        per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=y_pred)
        per_sample = tf.reduce_sum(per_element, axis=1)
        return tf.reduce_mean(per_sample)

    return loss
class AccuracyMultilabelWithLogits(tf.keras.metrics.Mean):
    """Element-wise accuracy for multilabel models that output raw logits.

    Each prediction is thresholded to a boolean and compared with the
    boolean-cast target. The metric value is the running mean of these hits
    (inherited from ``tf.keras.metrics.Mean``): the fraction of correct label
    decisions over everything accumulated since the last reset.

    Args:
        name: metric name reported by Keras.
        thresh: probability threshold applied to ``sigmoid(logits)``. The
            default 0.5 is evaluated directly on the logits (``logit > 0``),
            which is equivalent and skips the sigmoid.
        **kwargs: forwarded to ``tf.keras.metrics.Mean``.
    """

    def __init__(self, name='acc_avg_multilabel', thresh = 0.5, **kwargs):
        super(AccuracyMultilabelWithLogits, self).__init__(name=name, **kwargs)
        # NOTE(review): never updated by this class; kept so the metric's
        # variable set is unchanged for any code that inspects it.
        self.correct_label = self.add_weight(name='tl', initializer='zeros')
        self.thresh = thresh

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the per-label hits of one batch.

        ``y_true`` is cast to ``tf.bool``, so any non-zero target counts as
        positive. NOTE(review): this assumes 0/1 (or boolean) targets;
        -1/+1 targets would all be read as positive — confirm with the loader.

        ``sample_weight`` is forwarded to ``Mean.update_state``. Keras then
        applies it to the hit matrix, which has the shape of ``y_pred``.
        """
        y_true = tf.cast(y_true, tf.bool)
        if self.thresh == 0.5:
            # sigmoid(z) > 0.5  <=>  z > 0, so the sigmoid is unnecessary.
            y_mlbl = tf.cast(y_pred > 0., tf.bool)
        else:
            y_mlbl = tf.cast(tf.sigmoid(y_pred) > self.thresh, tf.bool)
        compare = tf.cast(tf.equal(y_mlbl, y_true), tf.float32)
        # Forward the caller's weights instead of silently discarding them.
        return super(AccuracyMultilabelWithLogits, self).update_state(
            compare, sample_weight=sample_weight)

def accuracy_multilabel_with_logits(thresh = 0.5):
    """Build a stateless multilabel accuracy function that works on logits.

    Args:
        thresh: probability threshold on ``sigmoid(logits)``. 0.5 is applied
            directly as ``logit > 0``, which is equivalent.

    Returns:
        An ``acc(y_true, y_pred)`` callable. It returns the fraction of
        correct label decisions in the batch: per-label accuracy averaged
        over axis 0, then averaged over labels.
    """
    def acc(y_true, y_pred):
        if thresh == 0.5:
            output = tf.cast(y_pred > 0., tf.int32)
        else:
            output = tf.cast(tf.sigmoid(y_pred) > thresh, tf.int32)
        # tf.equal requires matching dtypes; targets may arrive as bool or float.
        # NOTE(review): assumes 0/1 targets — confirm with the data loader.
        targets = tf.cast(y_true, tf.int32)
        compare = tf.cast(tf.equal(output, targets), tf.float32)
        acc_perlabel = tf.reduce_mean(compare, axis=0)
        return tf.reduce_mean(acc_perlabel)

    return acc

def compute_confusion_matrix(model, ds_test, thresh = 0.5):
    """Compute per-label confusion matrices of ``model`` over ``ds_test``.

    Args:
        model: Keras model returning logits.
        ds_test: iterable of ``(x, y)`` batches.
        thresh: probability threshold on ``sigmoid(logits)``. 0.5 is applied
            directly on the logits (``logit > 0``), which is equivalent.

    Returns:
        The array returned by sklearn's ``multilabel_confusion_matrix``:
        one ``[[TN, FP], [FN, TP]]`` matrix per label.

    Raises:
        ValueError: if ``ds_test`` yields no batch.
    """
    y_true = []
    y_pred = []
    for x, y in ds_test:
        # A direct call avoids model.predict's per-call setup overhead when
        # invoked once per batch inside a Python loop.
        logits = model(x, training=False)
        y_true.append(y)
        if thresh == 0.5:
            y_pred.append(logits > 0.)
        else:
            y_pred.append(tf.sigmoid(logits) > thresh)
    if not y_true:
        raise ValueError("ds_test yielded no batches; cannot build a confusion matrix")
    y_true = tf.concat(y_true, axis=0).numpy()
    y_pred = tf.concat(y_pred, axis=0).numpy()
    return multilabel_confusion_matrix(y_true, y_pred)

def _print_metric_section(per_label_title, avg_title, values, txt_file):
    """Print one metric value per label, then its mean over labels."""
    separator = "-----------------------------\n"
    print(separator + per_label_title, file=txt_file)
    print(values, file=txt_file)
    print(separator + avg_title, file=txt_file)
    print(np.mean(values), file=txt_file)


def compute_metrics(conf_matrix, txt_file=None):
    """Print per-label and averaged metrics derived from multilabel confusion matrices.

    Args:
        conf_matrix: array of shape ``(n_labels, 2, 2)``, one
            ``[[TN, FP], [FN, TP]]`` matrix per label (the layout produced by
            sklearn's ``multilabel_confusion_matrix``).
        txt_file: writable text stream; ``None`` prints to stdout.

    Reports the row-normalized percentages, then accuracy, precision,
    recall and specificity for each label, each followed by its mean.
    A metric whose denominator is zero for some label (e.g. precision for a
    label that is never predicted) is ``nan`` for that label and makes its
    average ``nan``; the corresponding numpy RuntimeWarning is suppressed.
    Print options are changed only while this function runs.
    """
    conf_matrix = np.asarray(conf_matrix)
    with np.printoptions(suppress=True), np.errstate(divide="ignore", invalid="ignore"):
        conf_matrix_perc = conf_matrix / np.sum(conf_matrix, axis=(1, 2), keepdims=True)
        print("-----------------------------\nPercent [[TN,FP],[FN,TP]]", file=txt_file)
        print(conf_matrix_perc, file=txt_file)

        # (TN + TP) / total
        accuracy = np.trace(conf_matrix, axis1=1, axis2=2) / np.sum(conf_matrix, axis=(1, 2))
        _print_metric_section("Accuracy per attributes", "Avg Accuracy", accuracy, txt_file)

        # TP / (FP + TP): column 1 holds the predicted positives.
        precision = conf_matrix[:, 1, 1] / np.sum(conf_matrix[:, :, 1], axis=1)
        _print_metric_section("Precision per attributes", "Avg precision", precision, txt_file)

        # TP / (FN + TP): row 1 holds the actual positives.
        recall = conf_matrix[:, 1, 1] / np.sum(conf_matrix[:, 1, :], axis=1)
        _print_metric_section("Recall/Sensibility per attributes", "Avg Recall/Sensibility",
                              recall, txt_file)

        # TN / (TN + FP): row 0 holds the actual negatives.
        specificity = conf_matrix[:, 0, 0] / np.sum(conf_matrix[:, 0, :], axis=1)
        _print_metric_section("Specificity per attributes", "Avg Specificity",
                              specificity, txt_file)


def parse_training_args():
    """Parse the command-line options of the training script.

    ``--alphaHKR -1`` is translated to ``float("inf")``. Every parsed
    option is echoed to stdout before returning.

    Returns:
        argparse.Namespace with the parsed options.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument
    add("--model_type", type=str, help="Model type")
    add("-bs", "--batch_size", type=int, help="Batch size")
    add("-e", "--epochs", type=int, help="Number of epochs")
    add("-lr", "--learning_rate", type=float, help="Learning rate")
    add("-is", "--input_scaling", type=float, help="Input scaling", default=255.0)
    add("--loss", type=str, default="HKR", help="Loss")
    add("--hinge_type", type=str, help="hinge type Hinge/HingeBalanced/HingeVar")
    add("--alphaHKR", type=float, help="alpha for HKR, -1 for inf")
    add("--min_margin", type=float, help="min_margin for HKR")
    add("--tau", type=float, help="tau for tauCCE")
    add("--perc", type=float, help="perc for HKRauto")
    add("--wandb", action="store_true", help="Enable wandb")
    add("--save", action="store_true", help="Save results")
    add("--classes", nargs="+", type=int, help="list of selected classes")

    parsed = cli.parse_args()
    # -1 is the command-line sentinel for an infinite HKR alpha.
    if parsed.alphaHKR == -1:
        parsed.alphaHKR = float("inf")

    print("-------- Training arguments --------")
    for key, value in vars(parsed).items():
        print(f"- {key}: {value}")
    return parsed





@tf.function
def train_step(x, y, model, loss_fn, t, optimizer, optim_marg, optim_margin=False,metric = None):
    """Run one optimization step on a batch and return ``(loss_value, metric value)``.

    Args:
        x, y: one batch of inputs and targets.
        model: Keras model producing logits; called with ``training=True``.
        loss_fn: ``loss_fn(y, logits)``. When ``optim_margin`` is True it must
            also expose ``min_margin``, used as the lower clip bound for ``t``.
        t: margin variable optimized jointly with the model when
            ``optim_margin`` is True; unused otherwise (may be None).
        optimizer: optimizer applied to ``model.trainable_weights``.
        optim_marg: optimizer applied to ``t``.
        optim_margin: Python bool; since it is not a tensor, ``tf.function``
            resolves the branches below at trace time.
        metric: ``metric(y, logits)`` evaluated on the training-mode logits.
    """
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)

    weights = model.trainable_weights
    if optim_margin:
        # Single gradient call for weights and margin; the last entry belongs to t.
        *model_grads, margin_grad = tape.gradient(loss_value, weights + [t])
    else:
        model_grads = tape.gradient(loss_value, weights)

    optimizer.apply_gradients(zip(model_grads, weights))
    if optim_margin:
        optim_marg.apply_gradients([(margin_grad, t)])
        # Keep the learned margin within [loss_fn.min_margin, 200].
        t.assign(tf.clip_by_value(t, clip_value_min=loss_fn.min_margin, clip_value_max=200))

    return loss_value, metric(y, logits)


def _reset_metric_state(metric):
    """Clear the accumulated state of a stateful Keras metric; no-op for plain callables."""
    reset = getattr(metric, "reset_state", None) or getattr(metric, "reset_states", None)
    if reset is not None:
        reset()


def fit_constraints(model, train,validation, loss_fct, optimizer, steps_per_epoch=50,
                    validation_step=50,
                    callbacks=None,
                    epochs=20,
                    verbose=2,
                    optim_margin=False,
                    margin_lr = 1.e-4,
                    metric = None
                   ):
    """Custom training loop built on ``train_step``, with Keras-style callbacks.

    Args:
        model: Keras model to train.
        train: iterable of ``(x, y)`` training batches, fully consumed each epoch.
        validation: iterable of ``(x, y)`` validation batches, fully consumed each epoch.
        loss_fct: ``loss_fct(y, logits)``. When ``optim_margin`` is True it must
            expose ``hingeloss.margins`` (the trainable margin) and ``min_margin``.
        optimizer: optimizer for the model weights.
        steps_per_epoch, validation_step, verbose: accepted for API
            compatibility; not used (each epoch iterates the full datasets).
        callbacks: Keras callbacks (``None`` means no callbacks).
        epochs: number of epochs.
        optim_margin: also optimize ``loss_fct.hingeloss.margins`` with a
            separate Adam optimizer.
        margin_lr: learning rate of that margin optimizer.
        metric: ``metric(y, logits)`` accuracy-like callable. If it is a
            stateful Keras metric, it is reset before every batch, so each
            call yields that batch's value. Without the reset, values would
            accumulate across all train and validation batches.

    Raises:
        TypeError: if ``metric`` is None.
    """
    if metric is None:
        raise TypeError("fit_constraints requires a `metric` callable")
    callbacks = [] if callbacks is None else list(callbacks)
    for c in callbacks:
        c.set_model(model)
    optim_marg = optimizers.Adam(learning_rate=margin_lr)
    # The margin variable is only trained in the margin-optimization setting.
    marg = loss_fct.hingeloss.margins if optim_margin else None

    logs = {}
    for c in callbacks:
        c.on_train_begin(logs=logs)
    loss = tf.metrics.Mean()
    regul = tf.metrics.Mean()  # never updated; kept so epoch logs keep the "regul" key
    cat_acc = tf.metrics.Mean()
    val_loss = tf.metrics.Mean()
    val_cat_acc = tf.metrics.Mean()
    g_norm = tf.metrics.Mean()  # never updated; kept so epoch logs keep the "grad_norm" key
    for e in range(epochs):
        start_time = time.time()
        for c in callbacks:
            c.on_epoch_begin(e, logs=None)

        for batch, (x, y) in enumerate(train):
            for c in callbacks:
                c.on_batch_begin(batch, logs=None)
            # Per-batch value, not a running mean over all previous batches.
            _reset_metric_state(metric)
            loss_value, acc = train_step(x, y, model, loss_fct, marg, optimizer,
                                         optim_marg, optim_margin=optim_margin, metric=metric)

            logs = {"loss": loss_value.numpy(),
                    "categorical_accuracy": acc.numpy().mean(),
            }
            for c in callbacks:
                c.on_batch_end(batch, logs=logs)

            loss.update_state(logs["loss"])
            cat_acc.update_state(logs["categorical_accuracy"])
            print(batch, logs, end='\r')

        for x, y in validation:
            logits = model(x, training=False)
            loss_value = loss_fct(y, logits)
            # The metric object is shared with training: reset so validation
            # accuracy is not mixed with training batches.
            _reset_metric_state(metric)
            acc = metric(y, logits)
            val_loss.update_state(loss_value.numpy())
            val_cat_acc.update_state(acc.numpy().mean())

        total_time = time.time() - start_time
        logs = {"loss": loss.result(),
                "acc": cat_acc.result(),
                "val_loss": val_loss.result(),
                "val_acc": val_cat_acc.result(),
                "grad_norm": g_norm.result(),
                "regul":  regul.result(),
                'time' :total_time }

        for c in callbacks:
            c.on_epoch_end(e, logs=logs)

        print(f"Epoch {e + 1}/{epochs}")
        print(
            f"time : {total_time:.2f}s *** loss:{loss.result():0.2f} val_loss:{val_loss.result():0.2f}  acc:{cat_acc.result() * 100:0.1f}% val_acc:{val_cat_acc.result() * 100:0.1f}%  ")

        if optim_margin:
            print("min m", tf.reduce_min(loss_fct.hingeloss.margins).numpy(),
              "mean m", tf.reduce_mean(loss_fct.hingeloss.margins).numpy(),"max m", tf.reduce_max(loss_fct.hingeloss.margins).numpy())
        loss.reset_states()
        cat_acc.reset_states()
        val_loss.reset_states()
        val_cat_acc.reset_states()
        g_norm.reset_states()
        regul.reset_states()
        sys.stdout.flush()

    for c in callbacks:
        c.on_train_end(logs=logs)




if __name__ == "__main__":
    # Script entry point: parse options, load CelebA multilabel data, build a
    # Lipschitz VGG-style model, train it with fit_constraints, then write the
    # confusion-matrix report and model artefacts under ./outputs.

    args = parse_training_args()
    wandb_active, save_active = args.wandb, args.save
    # Removed from args before args is passed as the wandb config below.
    del args.wandb, args.save
    # NOTE(review): save_active is never read; outputs below are written
    # regardless of --save.

    if wandb is None:
        wandb_active = False
    if wandb_active:
        # Retry until wandb initialises. Catch Exception rather than using a
        # bare except, so KeyboardInterrupt/SystemExit can still abort the loop.
        while True:
            try:
                wandb.init(entity="xxxxxxxx", project = 'celebA_multilabel', config=args)
                break
            except Exception:
                print("Retrying wandb")
                time.sleep(10)

    ###########################################################################
    # load data
    ###########################################################################
    ds_train, ds_test, metadata = loader.get_celeb_a_multilabel()

    '''lbl_sum = tf.zeros((40,))
    nb_sample = 0
    for x,y in ds_train:
        if len(y.shape)>1:
            lbl_sum += tf.reduce_sum(y,axis=0)
            nb_sample += len(y)
        else:
            lbl_sum += y
            nb_sample += 1

    print(lbl_sum/nb_sample)
    train_stats = (lbl_sum/nb_sample).numpy()    
    lbl_sum = tf.zeros((40,))
    nb_sample = 0
    for x,y in ds_test:
        if len(y.shape)>1:
            lbl_sum += tf.reduce_sum(y,axis=0)
            nb_sample += len(y)
        else:
            lbl_sum += y
            nb_sample += 1

    print(lbl_sum/nb_sample)
    test_stats = (lbl_sum/nb_sample).numpy() 
    stats = pd.DataFrame([train_stats,test_stats],columns=metadata["class_names"],index=['train','test'])
    for ll in range(0,40,6):
        print(stats[stats.columns[ll:ll+6]])'''
    # Quick sanity dump of one raw training example.
    x,y = next(iter(ds_train))
    print(x.shape)
    print(tf.reduce_min(x))
    print(tf.reduce_max(x))
    print(tf.reduce_mean(x))
    print(y.shape)
    print(y)
    print("All classes :", metadata["class_names"])

    # Boolean mask over all attributes selecting the labels to train on.
    # Builtin `bool` replaces `np.bool`, which was removed in NumPy 1.24.
    if args.classes is None:
        nb_classes = len(metadata["class_names"])
        all_selected_classes = np.ones((nb_classes,)).astype(bool)
    else:
        print(list(args.classes))
        nb_classes = len(set(args.classes))
        one_hot = np.eye(len(metadata["class_names"]))[args.classes]
        all_selected_classes = np.sum(one_hot,axis=0).astype(bool)
    print(all_selected_classes)
    print("nb_classes ", nb_classes)
    print("Selected classes :", np.asarray(metadata["class_names"])[all_selected_classes])

    # Suffix identifying this run in every output file name.
    if nb_classes==1:
        label = np.asarray(metadata["class_names"])[all_selected_classes]
        suffix_save = "_"+label[0]+"_"+str(args.model_type)+"_"+str(args.hinge_type)+"_alpha"+str(args.alphaHKR)+"_mm"+str(args.min_margin)+"_epoch"+str(args.epochs)
    else:
        suffix_save = "_"+str(nb_classes)+"labels_"+str(args.model_type)+"_"+str(args.hinge_type)+"_alpha"+str(args.alphaHKR)+"_mm"+str(args.min_margin)+"_epoch"+str(args.epochs)

    print(suffix_save)

    norm_factor = args.input_scaling #255.0 #1.0
    batch_size = args.batch_size
    feat_factor = 4
    smallvgg16_structure = dict(
        conv_sizes=(
            (feat_factor*4, feat_factor*4),
            (feat_factor*8, feat_factor*8),
            (feat_factor*16, feat_factor*16, feat_factor*16),
            (feat_factor*32, feat_factor*32, feat_factor*32),
            (feat_factor*32, feat_factor*32, feat_factor*32),
        ),
        dense_sizes=(feat_factor*64, feat_factor*64),
        name="smallVGG16",
    )

    ortho_layers_params = dict(
        conv=ClassParam(deellip.layers.OrthoConv2D, kernel_size=(3, 3), padding="circular", use_bias=True, regul_lorth=0),
        dense=ClassParam(deellip.layers.SpectralDense, use_bias=True),
        last_dense=ClassParam(deellip.layers.FrobeniusDense, disjoint_neurons=True, use_bias=False),
        pooling=None, #ClassParam(ScaledL2NormPooling2D, pool_size=(2, 2)),
        activation=deellip.activations.GroupSort2,
        normalization=None,
        dropout=None
    )

    # `vgg` presumably comes from `from dlt.model_factory import *` — verify.
    layers_params = vgg.lip_layers_params
    if args.model_type == "ortho":
        layers_params = ortho_layers_params

    ###########################################################################
    # data preparation and augmentation
    ###########################################################################
    ds_train, ds_test = pipeline.prepare_data(
        ds_train,
        ds_test,
        preparation_x=[
            lambda x: tf.cast(x, dtype=tf.float32) / norm_factor,
        ],
        preparation_y=[
            lambda y: y[all_selected_classes]
            #lambda y: tf.one_hot(tf.cast(y, dtype=tf.int32), depth=metadata["nb_classes"])
        ],
        augmentation_x=[
            aug.aug_layer(aug.random_resizedcrop, scale=(0.5, 1.0), ratio=(3 / 4, 4 / 3)),
            tf.image.random_flip_left_right,
            aug.aug_layer(aug.random_brightness, max_delta=0.2),
            aug.aug_layer(aug.random_contrast, lower=0.6, upper=1.4),
            aug.aug_layer(aug.cutout, pad_size=2),
        ],
        batch_size=batch_size,
    )

    ###########################################################################
    # build model
    ###########################################################################
    #loss = loss_multilabel_cross_entropy_with_logits()
    loss = HKR_multilabel(alpha = args.alphaHKR,  min_margin=args.min_margin, nb_class = nb_classes,hinge_type = args.hinge_type)
    optim_model = optimizers.Adam(learning_rate=args.learning_rate)
    strategy = distributed.get_distribution_strategy()
    with strategy.scope():
        kwargs = {}
        kwargs.update(layers_params) #binary_layers_params)
        kwargs.update(smallvgg16_structure)
        # kwargs["patch_size"] = 2
        model = vgg.VGG(metadata["input_shape"], nb_classes, **kwargs)
        model.compile(
            loss=loss, #losses.CategoricalCrossentropy(from_logits=True),
            metrics=[AccuracyMultilabelWithLogits(),"accuracy"],
            optimizer=optim_model,
        )
    #tf.keras.utils.plot_model(model, show_shapes=True)
    model.summary()

    callbacks = []
    if wandb_active:
        callbacks += [wandb.keras.WandbCallback()]
        callbacks += [utils_train.WandbLRLogger(model.optimizer)]

    ###########################################################################
    # fit model
    ###########################################################################
    #model.fit(ds_train, validation_data=ds_test, epochs=args.epochs,callbacks=callbacks)
    # Margins are only learned for the "HingeVar" hinge type.
    optim_margin = args.hinge_type == "HingeVar"
    fit_constraints(model, ds_train,ds_test, loss, optim_model,
                    steps_per_epoch=metadata["nb_samples_train"]//batch_size,
                    validation_step=metadata["nb_samples_test"]//batch_size,
                    callbacks=callbacks,
                    epochs=args.epochs,
                    verbose=2,
                    optim_margin=optim_margin,
                    margin_lr = 1.e-4,
                    metric = AccuracyMultilabelWithLogits()
                   )
    conf_matrix =  compute_confusion_matrix(model,ds_test)

    np.set_printoptions(suppress=True)

    # Create the output directory so the open()/save() calls below cannot
    # fail on a fresh checkout.
    path_output = "outputs"
    os.makedirs(path_output, exist_ok=True)
    with open(os.path.join(path_output, "confusion_matrix_"+str(feat_factor)+suffix_save+".txt"), "w") as txt_file:
        print("-----------------------------\nModel\nfeat_factor ", feat_factor,file=txt_file)
        print("Num samples [[TN,FP],[FN,TP]]", file=txt_file)
        print(conf_matrix, file=txt_file)
        compute_metrics(conf_matrix, txt_file=txt_file)

    model.save(os.path.join(path_output, "model"+suffix_save+"_fullmodel.h5"))
    model.save_weights(os.path.join(path_output, "model"+suffix_save+".h5"))
    with open(os.path.join(path_output, "model"+suffix_save+".json"), "w") as json_file:
        json_file.write(model.to_json())

    # Export a second model built from `model` by deel-lip's vanillaModel.
    vmodel = vanillaModel(model)
    vmodel.summary()
    vmodel.save(os.path.join(path_output, "vmodel"+suffix_save+"_fullmodel.h5"))
    vmodel.save_weights(os.path.join(path_output, "vmodel"+suffix_save+".h5"))
    with open(os.path.join(path_output, "vmodel"+suffix_save+".json"), "w") as json_file:
        json_file.write(vmodel.to_json())
    # 5. W&B: plot prediction values
    #if wandb_active:
    #    preds = model.predict(ds_test)
    #    utils_train.wandb_log_plot_predictions(preds)



