#!/usr/bin/env python
# coding: utf-8
# %%
# tf.__version__==1.14

# %%


from __future__ import absolute_import, division, print_function
import os, sys, shutil, glob, datetime, time, csv, pydot

import matplotlib
matplotlib.use("agg")
from matplotlib import pylab as plt
from sklearn.metrics import roc_curve, auc, roc_auc_score, confusion_matrix
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.keras import layers 
import optuna

import resnet50

#layers = tf.keras.layers


# %%


# Select the visible GPU and switch TensorFlow into eager mode.
gpu = "7"    ###############
gpu_config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=False,
    gpu_options=tf.GPUOptions(
        allow_growth=True, visible_device_list=gpu))
tf.enable_eager_execution(config=gpu_config)

#np.random.seed(7)
#tf.set_random_seed(7)

# Short alias for the contrib eager module.
tfe = tf.contrib.eager


# ### Parameters

# %%
# Global experiment parameters (read as module-level names by objective()).
# Training metrics are printed/logged on the first batch of an epoch and on
# every `display_step`-th batch after that.
display_step = 10
# Dataset selector; picks the configuration branch below.
name_dataset = "CIFAR100"
# Number of output classes (size of the final Dense layer and of the
# confusion matrix).
nb_cls = 2
# Which quantity is held fixed while the imbalance ratio changes.
protocol = "Ntot" # imbalance protocol: fix "Nmnr", "Nmjr", or "Ntot"
# Loss selector used by objective() and loss_fn().
loss_type = "softmax_cross_entropy" # softmax_cross_entropy, focal, cost_sensitive_softmax_cross_entropy, cost_sensitive_focal, softmax_cross_entropy_OS ###
# "tuning": hyperparameters are suggested by the Optuna trial;
# "stat": hyperparameters are read from the per-IR dictionaries below.
exp_phase = "stat" # tuning, stat ###

# Dataset-specific settings.  Each protocol branch derives the minority and
# majority training-set sizes (Nmnr, Nmjr) from the imbalance ratio IR and
# records which quantity is held fixed (num_fix) and which one varies
# (name_change / num_change).  num_fix, name_change and num_change are used
# further down to build the study name and the log/checkpoint directory names.
if name_dataset == "CIFAR100":
    # The trailing comments list the value used for each IR setting.
    IR = 1               #  99,  50,  10,   5,  1 ###
    n_trials = 2        #   2,   4,  10,  25, 50 ###
    nb_epoch = 40        # 100, 110, 120, 120, 40 ###
    valid_epoch_step = 2 #   1,    1,  2,   2,  2 ###
    #IR = 500              # 990, 500, 100,  50, 10,   1 ###
    #n_trials = 10         #   2,   4,  10,  25, 50, 300 ###
    #nb_epoch = 110        # 100, 110, 120, 120, 40,  40 ###
    #valid_epoch_step = 1  #   1,   1,   1,   2,  2,   2 ###

    if protocol == "Nmnr":
        # Minority size fixed; majority size grows with IR.
        Nmnr = 400
        Nmjr = int(Nmnr * IR)
        name_change = "Nmjr"
        num_change = Nmjr
        num_fix = Nmnr

        # Per-IR hyperparameters (keys are IR values) that objective() reads
        # when exp_phase == "stat".
        if Nmnr == 400:
            lr_dict_sce = {1:0.1 , 5:0.01, 10:0.01, 50:0.001, 99:0.001} # for "sce"
            lr_dict_focal = {1:3.8e-3 , 5:1.2e-2, 10:8.1e-3, 50:1.4e-3, 99:1.6e-3} # for "focal"
            gamma_dict_focal = {1:3.5920, 5:1.1123, 10:1.0482, 50:3.9070, 99:1.4719} # for "focal"
            lr_dict_cssce =  {1:0.1, 5:0.01, 10:0.01, 50:0.001, 99:0.001} # for "cssce"
            beta_dict_cssce =  {1:0.99999, 5:0.99999, 10:0.99999, 50:0.999, 99:0.99999}
            # NOTE(review): objective() reads lr_dict_sce, not lr_dict_sceos,
            # for loss_type "softmax_cross_entropy_OS" -- confirm which is intended.
            lr_dict_sceos =  {1:0.01, 5:0.01, 10:0.06, 50:0.035, 99:0.02}
            lr_dict_csfocal = {1:0.01, 5:0.01, 10:0.01, 50:0.001, 99:0.001}
            beta_dict_csfocal = {1:0.999, 5:0.999, 10:0.99, 50:0.99, 99:0.99}
            gamma_dict_csfocal =  {1:2.43, 5:2.05, 10:1.32, 50:1.09, 99:1.47}
        elif Nmnr == 40:
            # NOTE(review): no values are defined for this setting; with
            # exp_phase == "stat" objective() would fail with NameError on the
            # missing dictionaries.
            #lr_dict_sce = {1:0.1 , 5:0.01, 10:0.01, 50:0.001, 99:0.001} # for "sce"
            #lr_dict_focal = {1:3.8e-3 , 5:1.2e-2, 10:8.1e-3, 50:1.4e-3, 99:1.6e-3} # for "focal"
            #gamma_dict_focal = {1:3.5920, 5:1.1123, 10:1.0482, 50:3.9070, 99:1.4719} # for "focal"
            #lr_dict_cssce =  {1:0.1, 5:0.01, 10:0.01, 50:0.001, 99:0.001} # for "cssce"
            #beta_dict_cssce =  {1:0.99999, 5:0.99999, 10:0.99999, 50:0.999, 99:0.99999}
            #lr_dict_sceos =  {1:0.01, 5:0.01, 10:0.06, 50:0.035, 99:0.02}
            #lr_dict_csfocal = {1:0.01, 5:0.01, 10:0.01, 50:0.001, 99:0.001}
            #beta_dict_csfocal = {1:0.999, 5:0.999, 10:0.99, 50:0.99, 99:0.99}
            #gamma_dict_csfocal =  {1:2.43, 5:2.05, 10:1.32, 50:1.09, 99:1.47}
            pass

    elif protocol == "Ntot":
        # Total size fixed; it is split so that Nmjr/Nmnr is roughly IR
        # (Nmnr is truncated to an integer).
        Ntot = 800
        Nmnr = int(Ntot/(1+IR))
        Nmjr = Ntot - Nmnr
        name_change = "Nmnr"
        num_change = Nmnr
        num_fix = Nmnr + Nmjr

        if Ntot == 800:
            # NOTE(review): every hyperparameter dictionary is commented out
            # here, so with exp_phase == "stat" objective() would fail with
            # NameError (e.g. on lr_dict_sce) under this protocol.
            #lr_dict_sce = {1:0.1 , 5:0.01, 10:0.01, 50:0.001, 99:0.001} # for "sce"
            #lr_dict_focal = {1:3.8e-3 , 5:1.2e-2, 10:8.1e-3, 50:1.4e-3, 99:1.6e-3} # for "focal"
            #gamma_dict_focal = {1:3.5920, 5:1.1123, 10:1.0482, 50:3.9070, 99:1.4719} # for "focal"
            #lr_dict_cssce =  {1:0.1, 5:0.01, 10:0.01, 50:0.001, 99:0.001} # for "cssce"
            #beta_dict_cssce =  {1:0.99999, 5:0.99999, 10:0.99999, 50:0.999, 99:0.99999}
            #lr_dict_sceos =  {1:0.01, 5:0.01, 10:0.06, 50:0.035, 99:0.02}
            #lr_dict_csfocal = {1:0.01, 5:0.01, 10:0.01, 50:0.001, 99:0.001}
            #beta_dict_csfocal = {1:0.999, 5:0.999, 10:0.99, 50:0.99, 99:0.99}
            #gamma_dict_csfocal =  {1:2.43, 5:2.05, 10:1.32, 50:1.09, 99:1.47}
            pass

    elif protocol == "Nmjr":
        # Majority size fixed; minority size shrinks with IR.
        # NOTE(review): no hyperparameter dictionaries are defined for this
        # protocol either.
        Nmjr = 400
        Nmnr = int(Nmjr/IR)
        name_change = "Nmnr"
        num_change = Nmnr
        num_fix = Nmjr
    else:
        raise ValueError

    # Imbalance ratios for which separate test sets are built in objective().
    ls_irtest = [1, 5, 10, 50, 99]



elif name_dataset == "CelebA":
    # NOTE(review): the CelebA configuration below is a string literal, not
    # code.  Nothing inside it is executed, so IR, num_fix, etc. are never
    # defined on this branch and the study-name formatting that follows this
    # if/else would raise NameError.
    """IR = 1  # 99, 50, 10, 5, 1 ###
    n_trials = 2 # 2, 4, 15, 25, 50 ###
    nb_epoch = 40 # 100, 110, 120, 120, 40 ###
    valid_epoch_step = 2 # 1, 1, 2, 2, 2 ###

    lr_dict_sce = {1:3.1e-5 , 5:1e-3, 10:7.3e-4, 50:3.6e-4, 99:9.8e-4} # for "sce"
    lr_dict_focal = {1:3.8e-3 , 5:1.2e-2, 10:8.1e-3, 50:1.4e-3, 99:1.6e-3} # for "focal"
    gamma_dict_focal = {1:3.5920, 5:1.1123, 10:1.0482, 50:3.9070, 99:1.4719} # for "focal"
    #lr_dict_cssce =  {1:, 5:, 10:, 50:, 99:}
    #beta_dict_cssce =  {1:, 5:, 10:, 50:, 99:}
    #lr_dict_csfocal =  {1:, 5:, 10:, 50:, 99:}
    #beta_dict_csfocal = {1:, 5:, 10:, 50:, 99:}
    #gamma_dict_csfocal =  {1:, 5:, 10:, 50:, 99:}    
    #lr_dict_sceos =  {1:, 5:, 10:, 50:, 99:}
    IRtest_list = [1, 5, 10, 50, 99]

    if protocol == "Nmnr":
        Nmnr = 400
        Nmjr = int(Nmnr * IR)
        name_change = "Nmjr"
        num_change = Nmjr
        num_fix = Nmnr        
    elif protocol == "Ntot":
        Nmnr = 0
        Nmjr = 0
        name_change = "Nmnr"
        num_change = Nmnr
        num_fix = Nmnr + Nmjr
        xxx
    elif protocol == "Nmjr":
        Nmnr = 0
        Nmjr = 0
        name_change = "Nmnr"
        num_change = Nmnr
        num_fix = Nmjr
        xxx"""
    pass
else:
    raise ValueError
    
# Optuna study name and its SQLite storage URL share the same identifying
# fields, so both are formatted from one tuple.
_name_fields = (name_dataset, protocol, num_fix, IR, loss_type, exp_phase)
study_name = '{}_{}{}_IR{}_{}_{}'.format(*_name_fields)
storage = 'sqlite:///{}_{}{}_IR{}_{}_{}.db'.format(*_name_fields)


# %%
def objective(trial):
    bs = 128
    if loss_type == "softmax_cross_entropy":
        if exp_phase == "tuning":
            if IR > 10:
                lr = trial.suggest_categorical('learning_rate', [1e-5,1e-4,1e-3,1e-2])
            else:
                lr = trial.suggest_categorical('learning_rate', [1e-5,1e-4,1e-3,1e-2,1e-1])

        elif exp_phase == "stat":
            lr = lr_dict_sce[IR]

    elif loss_type == "softmax_cross_entropy_OS":
        if exp_phase == "tuning":
            if IR > 10:
                lr = trial.suggest_categorical('learning_rate', [1e-5,1e-4,1e-3,1e-2])
            else:
                lr = trial.suggest_categorical('learning_rate', [1e-5,1e-4,1e-3,1e-2,1e-1])
                
        elif exp_phase == "stat":
            lr = lr_dict_sce[IR]
            
    elif loss_type == "cost_sensitive_softmax_cross_entropy":
        if exp_phase == "tuning":
            if IR > 10:
                lr = trial.suggest_categorical('learning_rate', [1e-5,1e-4,1e-3,1e-2])
            else:
                lr = trial.suggest_categorical('learning_rate', [1e-5,1e-4,1e-3,1e-2,1e-1])
                
            if Nmnr != 400:
                beta = trial.suggest_categorical('beta', [0.9, 0.99, 0.999, 0.9999, 0.99999])
            else:
                beta = trial.suggest_categorical('beta', [0.99, 0.999, 0.9999, 0.99999])

        elif exp_phase == "stat":
            lr = lr_dict_cssce[IR]
            beta = beta_dict_cssce[IR]

    elif loss_type == "cost_sensitive_focal":
        if exp_phase == "tuning":
            if IR > 10:
                lr = trial.suggest_categorical('learning_rate', [1e-5,1e-4,1e-3,1e-2])
            else:
                lr = trial.suggest_categorical('learning_rate', [1e-5,1e-4,1e-3,1e-2,1e-1])

            if Nmnr != 400:
                beta = trial.suggest_categorical('beta', [0.9, 0.99, 0.999, 0.9999, 0.99999])
            else:
                beta = trial.suggest_categorical('beta', [0.99, 0.999, 0.9999, 0.99999])
                
            gamma = trial.suggest_uniform('gamma', 0., 4.)
            
        elif exp_phase == "stat":
            lr = lr_dict_csfocal[IR]
            beta = beta_dict_csfocal[IR]
            gamma = gamma_dict_csfocal[IR]
        
    elif loss_type == "focal":
        if exp_phase == "tuning":
            if IR > 10:
                lr = trial.suggest_categorical('learning_rate', [1e-5,1e-4,1e-3,1e-2])
            else:
                lr = trial.suggest_categorical('learning_rate', [1e-5,1e-4,1e-3,1e-2,1e-1])
                
            gamma = trial.suggest_uniform('gamma', 0., 4.)
            
        elif exp_phase == "stat":
            lr = lr_dict_focal[IR]
            gamma = gamma_dict_focal[IR]

    else:
        raise ValueError

    print("IR: ", IR)
    print("Nmnr: ", Nmnr)
    print("Nmjr: ", Nmjr)
    print("bs ", bs)
    print("lr ", lr)
    
    # tblogdir_home
    tblog_prefix = "IR{},{}{},lr{:.7f},bs{}_".format(IR, name_change, num_change, lr, bs)
    tblogdir_home = "/data/t-miyagawa/logs/imbalance/tblogs/{}/mnrcls_0_{}_{}_{}_{}/IR{:04d}_{}_{:05d}".format(
        name_dataset, protocol, num_fix, loss_type, exp_phase, IR, name_change, num_change)
    if not os.path.exists(tblogdir_home):
        os.makedirs(tblogdir_home)
        print("Made: ", tblogdir_home)
        
    # ### Load data

    # In[4]:


    # Load data
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
    (x_valid, y_valid) = (x_train[40000:], y_train[40000:])
    (x_train, y_train) = (x_train[:40000], y_train[:40000])
    # Flatten
    y_train = np.array([value[0] for _, value in enumerate(y_train)])
    y_valid = np.array([value[0] for _, value in enumerate(y_valid)])
    y_test = np.array([value[0] for _, value in enumerate(y_test)])
    print("train shape: ", x_train.shape, y_train.shape)
    print("test  shape: ", x_test.shape, y_test.shape)
    print("valid shape: ", x_valid.shape, y_valid.shape)
    print(x_train.shape[0], "train samples")
    print(x_valid.shape[0], "valid samples")
    print(x_test.shape[0], "test samples")

    # In[5]:

    # Class generation
    y_train = np.array([0 if value == 0 else 1 for idx, value in enumerate(y_train)])
    y_valid = np.array([0 if value == 0 else 1 for idx, value in enumerate(y_valid)])
    y_test = np.array([0 if value == 0 else 1 for idx, value in enumerate(y_test)])
    print("original y_train", y_train)


    # Delete unnecessary data
    print("Deleting unnecessary data...")

    idx_train_mjr = np.where(y_train==1)[0]
    idx_train_mjr_remove = idx_train_mjr[Nmjr:]
    x_train = np.delete(x_train, idx_train_mjr_remove, axis=0)
    y_train = np.delete(y_train, idx_train_mjr_remove, axis=0)

    idx_train_mjr = np.where(y_train==0)[0]
    idx_train_mjr_remove = idx_train_mjr[Nmnr:]
    x_train = np.delete(x_train, idx_train_mjr_remove, axis=0)
    y_train = np.delete(y_train, idx_train_mjr_remove, axis=0)

    print("Done:")
    print(x_train.shape[0], "train samples")
    print(x_valid.shape[0], "valid samples")
    print(x_test.shape[0], "test samples")

    if x_train.shape[0] < Nmjr+Nmnr:
        print("\nWorning!!")
        print("Maybe user-defined value of Nmjr+Nmnr exceeds maximum possible Nmjr+Nmnr")    

    # Generate imbalanced test data
    dic_x_test = dict()
    dic_y_test = dict()

    for iter_irtest in ls_irtest:
        Nmnr_test = 100
        Nmjr_test = int(Nmnr_test * int(iter_irtest))
        idx_test_mjr = np.where(y_test==1)[0]
        idx_test_mjr_remove = idx_test_mjr[Nmjr_test:]
        dic_x_test[iter_irtest] = np.delete(x_test, idx_test_mjr_remove, axis=0)
        dic_y_test[iter_irtest] = np.delete(y_test, idx_test_mjr_remove, axis=0)
        #print("len(x_test) = total nb of test data: ", len(x_test))

        
    # In[6]:

    # Over-sampling
    if loss_type == "softmax_cross_entropy_OS":
        print("Over-sampling with IR={}...".format(IR))
        print("Num of minor class data points: from ", len(np.where(y_train==0)[0]))

        if IR != 1:
            mult = int(Nmjr/Nmnr)
            idx_train_mnr = np.where(y_train==0)[0]
            x_train_mnr = x_train[idx_train_mnr]
            y_train_mnr = y_train[idx_train_mnr]
            for i in range(mult - 1):
                x_train = np.append(x_train, x_train_mnr, axis=0)
                y_train = np.append(y_train, y_train_mnr, axis=0)

    print("Shuffling...")
    idx_perm = np.random.permutation(len(x_train))
    x_train = x_train[idx_perm]
    y_train = y_train[idx_perm]
    print("to ", len(np.where(y_train==0)[0]))

    
    # In[7]:


    # Data preprocessing
    x_train = np.float32(x_train)
    x_valid = np.float32(x_valid)
    x_train /= 127.5
    x_train -= -1
    x_valid /= 127.5
    x_valid -= -1
    for iter_irtest in ls_irtest:
        dic_x_test[iter_irtest] = np.float32(dic_x_test[iter_irtest])
        dic_x_test[iter_irtest] /= 127.5
        dic_x_test[iter_irtest] -= -1


    # In[8]:


    # Dataset
    print(y_train)
    dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train))
    dataset = dataset.shuffle(40000) # .shuffle -> .batch !!!
    dataset = dataset.batch(bs, drop_remainder=False) # .shuffle -> .batch !!!
    dataset_valid = tf.data.Dataset.from_tensor_slices((x_valid,y_valid)).batch(bs, drop_remainder=False)
    dataset_test = dict()
    for iter_irtest in ls_irtest:
        dataset_test[iter_irtest] = tf.data.Dataset.from_tensor_slices(
            (dic_x_test[iter_irtest], dic_y_test[iter_irtest])).batch(bs, drop_remainder=False) 

    # In[9]:


    # Model
    def get_model():
        """Build the ResNet50 backbone and attach the classification head.

        The backbone is created without its top layer and with average
        pooling.  A ``Flatten`` layer and a ``Dense`` layer with ``nb_cls``
        outputs (no activation argument is passed) are stored on the model as
        the attributes ``flatten`` and ``fc100``; they are not called here but
        by the loss functions.
        """
        backbone_kwargs = dict(
            data_format="channels_last",
            name='ResNet',
            include_top=False,
            pooling='avg')
        net = resnet50.ResNet50(**backbone_kwargs)

        # Head layers live on the model object so the losses can reach them.
        net.flatten = layers.Flatten()
        net.fc100 = layers.Dense(nb_cls, name="fc100")
        return net

    model = get_model()

    # Show layers
    #model.l2a.trainable_variables[3].name
    for iter_layer in model.layers:
        print("Layer: ", iter_layer.name)

    # In[10]:
    def loss_fn(model, x, y, loss_type=loss_type, training=True):
        """Compute the loss selected by ``loss_type`` for one batch.

        Args:
            model: backbone with ``flatten`` and ``fc100`` attributes.
            x: input batch passed to ``model``.
            y: integer (non-onehot) labels.
            loss_type: one of "softmax_cross_entropy",
                "softmax_cross_entropy_OS", "focal", "cost_sensitive_focal",
                "cost_sensitive_softmax_cross_entropy".  Defaults to the
                module-level setting at definition time.
            training: forwarded to ``model``.

        Returns:
            Tuple ``(loss, logits, bottleneck_features)``.

        Raises:
            KeyError: if ``loss_type`` is not one of the names above.
        """
        # Plain and over-sampled cross entropy share the same loss; the
        # over-sampling itself happens when the training data is prepared.
        if loss_type in ("softmax_cross_entropy", "softmax_cross_entropy_OS"):
            features = model(x, training=training)
            logits = model.fc100(model.flatten(features))
            ce = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)
            return ce, logits, features

        # gamma / beta below come from the enclosing scope.
        if loss_type == "focal":
            return focal_loss(model, x, y, gamma=gamma, training=training)

        if loss_type == "cost_sensitive_focal":
            _, class_weights = cost_matrix_generator(beta, y)
            return focal_loss(model, x, y, alpha=class_weights, gamma=gamma, training=training)

        if loss_type == "cost_sensitive_softmax_cross_entropy":
            return cost_sensitive_loss(model, x, y, beta, training=training)

        raise KeyError()

    # In[ ]:

    def cost_matrix_generator(beta, y):
        """Build the ``(1-beta)/(1-beta**n)`` weights from the class sizes.

        Args:
            beta: scalar hyperparameter.
            y: non-onehot label tensor (``.numpy()`` is called on it).

        Returns:
            Tuple ``(cost_matrix, alpha)``.  ``cost_matrix`` has one row per
            sample; column 0 is the weight computed from the size of the
            sample's own class, column 1 the weight from the other class.
            ``alpha`` holds the two weights in class order
            ``[label 0 (Nmnr), label 1 (Nmjr)]`` for the focal loss.
        """
        eps = tf.keras.backend.epsilon()
        class_counts = [Nmnr, Nmjr]

        # Label 0 is the class of size Nmnr, every other label the class of size Nmjr.
        per_sample_counts = [
            [Nmnr, Nmjr] if label == 0 else [Nmjr, Nmnr] for label in y.numpy()]
        cost_matrix = (1. - beta)/(1. - tf.pow(beta, per_sample_counts) + eps) # shape=(batch_size, nb_cls=2)

        alpha = (1. - beta)/(1. - tf.pow(beta, class_counts) + eps) # for focal loss
        return cost_matrix, alpha


    # In[ ]:


    def cost_sensitive_loss(model, x, y, beta, training=True):
        """Softmax cross entropy with a per-sample weight.

        Each sample is weighted by column 0 of the matrix returned by
        ``cost_matrix_generator(beta, y)``.

        Returns:
            Tuple ``(weighted_loss, logits, bottleneck_features)``.
        """
        features = model(x, training=training)
        logits = model.fc100(model.flatten(features))

        per_sample_costs = cost_matrix_generator(beta, y)[0] # (batch, nb_cls=2)
        sample_weights = per_sample_costs[:, 0]
        loss = tf.losses.sparse_softmax_cross_entropy(
            labels=y, logits=logits, weights=sample_weights)

        return loss, logits, features


    # In[ ]:


    def focal_loss(model, x, y, alpha=None, gamma=2.0, training=True):
        """Sigmoid focal loss: sum of ``alpha * (1 - p_t)**gamma * CE``.

        The model is run on ``x`` here; the per-class sigmoid cross entropy
        between the resulting logits and the one-hot encoding of ``y`` is
        modulated and summed over all samples and classes.

        Args:
            model: backbone with ``flatten`` and ``fc100`` attributes.
            x: input batch passed to ``model``.
            y: integer (non-onehot) labels, one-hot encoded with ``nb_cls``.
            alpha: optional weights; a leading axis is added so they broadcast
                over the batch.  When None every entry is weighted by
                ``1/nb_cls``.  Not assumed to be normalized.
            gamma: exponent applied to ``1 - p_t``.
            training: forwarded to ``model``.

        Returns:
            Tuple ``(summed focal loss, logits, bottleneck_features)``.
        """
        with tf.name_scope("focal_loss"):
            features = model(x, training=training)
            logits = model.fc100(model.flatten(features))
            targets = tf.one_hot(y, nb_cls, dtype=tf.float32)

            probs = tf.sigmoid(logits) # (batch, nb_cls)
            bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits) # (batch, nb_cls)
            # p_t: predicted probability where the target is 1, its complement elsewhere.
            p_t = tf.where(tf.equal(targets, 1), probs, 1.-probs)

            if alpha is not None:
                weights = tf.expand_dims(alpha, axis=0)
            else:
                weights = tf.ones_like(targets, dtype=tf.float32) / nb_cls

            modulated = bce * tf.pow(1. - p_t, gamma) * weights

            return tf.reduce_sum(modulated), logits, features


    def grad_fn(model, inputs, targets, training=True):
        """Run ``loss_fn`` under a gradient tape.

        Returns:
            Tuple ``(gradients, loss, logits)`` where ``gradients`` is taken
            with respect to ``model.variables`` (same order).
        """
        with tf.GradientTape() as recorder:
            loss, logits, _features = loss_fn(model, inputs, targets, training=training)
        gradients = recorder.gradient(loss, model.variables)
        return gradients, loss, logits


    # In[11]:


    class MetricFunction():
        """Classification metrics derived from a confusion matrix.

        ``self.metrics`` maps a metric name to a dict keyed by the class index
        (``0 .. nb_cls-1``) plus ``"macro"`` (mean over classes) and
        ``"micro"`` (computed from the counts summed over classes); ``"ACC"``
        additionally has ``"original"``.  One-vs-rest counts are used per class.

          SNS  TP/(TP+FN)              sensitivity (TPR, recall)
          SPC  TN/(TN+FP)              specificity (TNR)
          PRC  TP/(TP+FP)              precision
          ACC  (TP+TN)/(TP+FN+TN+FP)
          BAC  (SNS+SPC)/2
          F1   2*PRC*SNS/(PRC+SNS)
          GM   sqrt(SNS*SPC)
          MC   (TP*TN-FP*FN)/sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
          MK   PRC + TN/(TN+FN) - 1

        The ``SNS``/``SPC``/``PRC``/``REC`` entries suffixed ``50``/``70``/``90``
        are copies of the corresponding base metric (``REC`` copies ``SNS``).

        Any ratio whose denominator is zero is defined as ``0.`` (the
        convention the precision already followed), so an empty class in a
        batch no longer yields nan or a division error.

        Args:
            cm: square confusion matrix indexed ``cm[label, prediction]`` with
                ``nb_cls`` rows, or None to create placeholder entries
                (``-1.`` for MC and MK, ``0.`` for everything else).
        """

        # Base metrics, in the order they are stored in ``self.metrics``.
        _BASE_KEYS = ("SNS", "SPC", "PRC", "ACC", "BAC", "F1", "GM", "MC", "MK")
        # (alias prefix, base metric it copies), combined with every suffix.
        _ALIASES = (("SNS", "SNS"), ("SPC", "SPC"), ("PRC", "PRC"), ("REC", "SNS"))
        _SUFFIXES = ("50", "70", "90")

        @staticmethod
        def _ratio(num, den):
            """Return ``num/den``, or ``0.`` when ``den`` is zero."""
            return num/den if den != 0 else 0.

        def _from_counts(self, TP, TN, FP, FN):
            """Compute every base metric from one set of TP/TN/FP/FN counts."""
            ratio = self._ratio
            sns = ratio(TP, TP+FN)
            spc = ratio(TN, TN+FP)
            prc = ratio(TP, TP+FP)
            return {
                "SNS": sns,
                "SPC": spc,
                "PRC": prc,
                "ACC": ratio(TP+TN, TP+FN+TN+FP),
                "BAC": (sns + spc)/2,
                "F1": ratio(2*(prc * sns), prc + sns),
                "GM": np.sqrt(sns * spc),
                "MC": ratio((TP*TN) - (FP*FN),
                            np.sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN))),
                "MK": prc + ratio(TN, TN+FN) - 1,
            }

        def __init__(self, cm=None):
            alias_keys = tuple(
                prefix + suffix
                for prefix, _ in self._ALIASES for suffix in self._SUFFIXES)
            all_keys = self._BASE_KEYS + alias_keys

            if cm is None:
                self.metrics = dict()
                for key in all_keys:
                    lowest = -1. if key in ("MC", "MK") else 0.
                    entry = {i: lowest for i in range(nb_cls)}
                    entry["micro"] = lowest
                    entry["macro"] = lowest
                    if key == "ACC":
                        entry["original"] = 0.
                    self.metrics[key] = entry
                print("cm is not given: initialized with the loweset values.")
                return

            self.metrics = {key: dict() for key in all_keys}
            print("MetricFunction.metrics: Initialized.")
            assert nb_cls == len(cm)

            # Per-class one-vs-rest counts, accumulated for the micro average.
            totals = [0, 0, 0, 0]  # TP, TN, FP, FN
            for i in range(nb_cls):
                others = [j for j in range(nb_cls) if j != i]
                counts = (
                    cm[i, i],                                       # TP
                    sum(cm[j, k] for j in others for k in others),  # TN
                    sum(cm[j, i] for j in others),                  # FP
                    sum(cm[i, j] for j in others),                  # FN
                )
                for key, value in self._from_counts(*counts).items():
                    self.metrics[key][i] = value
                totals = [t + c for t, c in zip(totals, counts)]

            # For SNS, PRC and F1 the micro average equals the plain accuracy.
            micro = self._from_counts(*totals)
            for key in self._BASE_KEYS:
                entry = self.metrics[key]
                entry["macro"] = np.mean([entry[i] for i in range(nb_cls)])
                entry["micro"] = micro[key]
            self.metrics["ACC"]["original"] = (
                ((nb_cls/2) * self.metrics["ACC"]["micro"]) - ((nb_cls-2)/2))

            # Suffixed entries mirror their base metric (SPC* included, which
            # were previously left empty when a confusion matrix was given).
            for prefix, source in self._ALIASES:
                for suffix in self._SUFFIXES:
                    self.metrics[prefix + suffix] = dict(self.metrics[source])

            print("MetricFunction.metrics: Finished calculation.")


    # In[12]:


    def cm_auc_fn(logits, y, flag_auc=False): ################
        """Confusion matrix (and optionally per-class ROC AUC) for one batch.

        Args:
            logits: class scores, one row per sample; an eager tensor, an
                ndarray or a (nested) list.
            y: integer labels; an eager tensor, an ndarray or a list.
            flag_auc: when True, also compute the one-vs-rest ROC AUC of every
                class from the raw scores.

        Returns:
            Tuple ``(cm, roc_auc)``.  ``cm`` is indexed ``cm[label, pred]`` over
            the classes ``0 .. nb_cls-1``.  ``roc_auc`` is None unless
            ``flag_auc`` is set; otherwise a dict with one entry per class
            index plus ``'macro'`` (mean over classes).
        """
        def _as_array(value):
            # Eager tensors expose .numpy(); lists and ndarrays go through
            # np.asarray so that column indexing below also works for lists.
            return value.numpy() if hasattr(value, "numpy") else np.asarray(value)

        preds = tf.argmax(logits, axis=1, output_type=tf.int32).numpy()
        labels = _as_array(y)
        scores = _as_array(logits)

        class_ids = list(range(nb_cls))
        cm = confusion_matrix(y_true=labels, y_pred=preds, labels=class_ids) # subscript: (label, pred)

        if not flag_auc:
            return cm, None

        # One-hot labels give the binary target of each one-vs-rest problem.
        labels_oh = np.eye(nb_cls)[labels]

        roc_auc = dict()
        for i in class_ids:
            fpr, tpr, _ = roc_curve(labels_oh[:, i], scores[:, i])
            roc_auc[i] = auc(fpr, tpr)

        roc_auc['macro'] = np.mean([roc_auc[i] for i in class_ids])
        #fpr['micro'], tpr['micro'], _ = roc_curve(labels_oh.flatten(), scores.flatten())
        #roc_auc['micro'] = auc(fpr['micro'], tpr['micro'])

        return cm, roc_auc


    # In[13]:


    # #### debug ####
    # cm = np.array([
    #     [1,9,3],
    #     [7,500,5],
    #     [3,1,1]])
    # nb_cls = 3

    # metric = MetricFunction(cm)
    # metric.metrics
    # ################
    # from IPython.core.debugger import Pdb; Pdb().set_trace() 


    # In[14]:


    # Check point paths
    dir_ckptpath = "/data/t-miyagawa/logs/imbalance/ckpt/{}/mnrcls_0_{}_{}_{}_{}/IR{:04d}_{}_{:05d}".format(
        name_dataset, protocol, num_fix, loss_type, exp_phase, IR, name_change, num_change)

    if not os.path.exists(dir_ckptpath):
        os.makedirs(dir_ckptpath)
        print("Made: ", dir_ckptpath)

    list_ckpts = glob.glob(dir_ckptpath + "/ckpt*")
    if len(list_ckpts) != 0:
        print("\nckpts\n", list_ckpts)
        dir_ckptpath_bkp = dir_ckptpath.replace("/ckpt/", "/ckpt.bkp/")
        if not os.path.exists(dir_ckptpath_bkp):
            os.makedirs(dir_ckptpath_bkp)
            print("Made: ", dir_ckptpath_bkp)
        now = datetime.datetime(1,1,1).now().isoformat().replace("-","").replace(":",".")[:-3]
        for iter_ckpt in list_ckpts:            
            shutil.move(iter_ckpt, iter_ckpt + str(now))
            shutil.move(iter_ckpt + str(now), dir_ckptpath_bkp)
            print("Moved to ckpt.bkp:\nFile {}\nto\nDir {}".format(iter_ckpt, dir_ckptpath_bkp))


    # ### Training

    # In[15]:


    def get_optimizer():
        """Create the Adam optimizer with the learning rate ``lr`` of the enclosing scope."""
        adam = tf.train.AdamOptimizer(learning_rate=lr)
        return adam


    # Training phase
    optimizer = get_optimizer()
    global_step = tf.train.get_or_create_global_step()
    global_step.assign(0)
    
    # Create checkpoint
    ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)

    # Start training
    now = time.strftime("%Y%m%d%H%M%S")
    tblogdir = tblogdir_home + "/"+ str(tblog_prefix) + str(now)
    os.makedirs(tblogdir)
    summary_writer = tf.contrib.summary.create_file_writer(tblogdir, flush_millis=10000)
    with summary_writer.as_default(), tf.contrib.summary.always_record_summaries():
        # Initialization
        dictbest = MetricFunction(cm=None)
        dictbest_test = dict()
        for iter_irtest in ls_irtest:
            dictbest_test[iter_irtest] = MetricFunction(cm=None)
        flag_save = False
        best = 0.
        
        # Training loop 1
        for epoch in range(nb_epoch):
            # Training loop 2
            for iter_b, (x_trainb, y_trainb) in enumerate(dataset):
                # Calc loss and grad, and backpropagation 
                grads, loss, logits = grad_fn(model, x_trainb, y_trainb)
                optimizer.apply_gradients(zip(grads, model.variables), global_step)

                #Verbose
                if ((iter_b+1)%display_step == 0) or iter_b == 0:    
                    cm, rocauc = cm_auc_fn(logits, y_trainb, flag_auc=False)
                    metgen = MetricFunction(cm)
                    print("Epoch {}, Step{}: loss={:.7f}, ACC={:.7f}, mnrACC={:.7f}".format(
                        epoch+1, iter_b+1, loss, metgen.metrics["ACC"]["original"], metgen.metrics["ACC"][0]))
                    print("Epoch {}, Step{}:   mnrSNS={:.7f}, mnrF1={:.7f},  mnrGM={:.7f},  mnrMC={:.7f}, mnrBAC={:.7f}".format(
                        epoch+1, iter_b+1,
                        metgen.metrics["SNS"][0], 
                        metgen.metrics["F1"][0], 
                        metgen.metrics["GM"][0], 
                        metgen.metrics["MC"][0],
                        metgen.metrics["BAC"][0]))
                    print(cm)

                    # Tensorboard
                    tf.contrib.summary.scalar("train/loss", loss)
                    tf.contrib.summary.scalar("train/ACC/org", metgen.metrics["ACC"]["original"])
                    tf.contrib.summary.scalar("train/SNS/0", metgen.metrics["SNS"][0])
                    tf.contrib.summary.scalar("train/SNS/1", metgen.metrics["SNS"][1])
                    tf.contrib.summary.scalar("train/SPC/0", metgen.metrics["SPC"][0])
                    tf.contrib.summary.scalar("train/SPC/1", metgen.metrics["SPC"][1])                                        
                    tf.contrib.summary.scalar("train/PRC/0", metgen.metrics["PRC"][0])
                    tf.contrib.summary.scalar("train/PRC/1", metgen.metrics["PRC"][1])
                    tf.contrib.summary.scalar("train/BAC/0", metgen.metrics["BAC"][0])                                                            
                    tf.contrib.summary.scalar("train/F1/0", metgen.metrics["F1"][0])
                    tf.contrib.summary.scalar("train/F1/1", metgen.metrics["F1"][1])
                    tf.contrib.summary.scalar("train/GM/0", metgen.metrics["GM"][0])
                    tf.contrib.summary.scalar("train/MC/0", metgen.metrics["MC"][0])                  
                    tf.contrib.summary.scalar("train/MK/0", metgen.metrics["MK"][0])                 
                    
            # Validation
            ###################
            if (epoch == 0) or ((epoch+1) % valid_epoch_step == 0):
                # Lists for validation
                bs_list_valid = []
                loss_list_valid = []
                logit_list_valid = []
                y_validb_list = []
                cm_list_valid = []
                #rocauc_list_valid = []

                # Validation loop
                print("Now validation...")

                for iter_bv, (x_validb, y_validb) in enumerate(dataset_valid):
                    if ((iter_bv+1)%10 == 0) or (iter_bv == 0):
                        sys.stdout.write("\r iter: {}".format(iter_bv+1))
                        sys.stdout.flush()
                    # Calc
                    loss_valid, logits_valid, _ = loss_fn(model, x_validb, y_validb, training=False)
                    cm_valid, _ = cm_auc_fn(logits_valid, y_validb, flag_auc=False)

                    # Append
                    bs_list_valid.append(len(y_validb))
                    loss_list_valid.append(loss_valid)
                    logit_list_valid.extend(logits_valid)
                    y_validb_list.extend(y_validb)
                    cm_list_valid.append(cm_valid)
                print("")
                cm_valid_tot = np.sum(cm_list_valid, axis=0)
                metgen_valid = MetricFunction(cm_valid_tot)

                #print("bs_list_valid, loss_list_valid: ", bs_list_valid, loss_list_valid) ###
                loss_valid = np.sum([i*j.numpy() for i,j in zip(bs_list_valid, loss_list_valid)]) / np.sum(bs_list_valid)

                #Verbose valid 
                print("Epoch {}: vloss={:.7f}, vACC={:.7f}, mnrvACC={:.7f}".format(
                    epoch+1,
                    loss_valid, 
                    metgen_valid.metrics["ACC"]["original"],
                    metgen_valid.metrics["ACC"][0],
                    #rocauc_valid['macro']
                )
                     )
                print("Epoch {}: mnrvACC={:.7f}, mnrvSNS={:.7f}, mnrvF1={:.7f}, mnrvGM={:.7f}, mnrvMC={:.7f}, mnrvBAC={:.7f}".format(
                    epoch+1, 
                    metgen_valid.metrics["ACC"][0],
                    metgen_valid.metrics["SNS"][0],
                    metgen_valid.metrics["F1"][0],
                    metgen_valid.metrics["GM"][0],
                    metgen_valid.metrics["MC"][0],
                    metgen_valid.metrics["BAC"][0]
                )
                     )
                print("Total Confision Matirx for Validation Data:")
                print(cm_valid_tot)

                # Tensorboard validation
                tf.contrib.summary.scalar("valid/loss", loss_valid)
                tf.contrib.summary.scalar("valid/ACC/org", metgen_valid.metrics["ACC"]["original"])
                tf.contrib.summary.scalar("valid/SNS/0", metgen_valid.metrics["SNS"][0])
                tf.contrib.summary.scalar("valid/SPC/0", metgen_valid.metrics["SPC"][0])                                       
                tf.contrib.summary.scalar("valid/PRC/0", metgen_valid.metrics["PRC"][0])
                tf.contrib.summary.scalar("valid/PRC/1", metgen_valid.metrics["PRC"][1])
                tf.contrib.summary.scalar("valid/BAC/0", metgen_valid.metrics["BAC"][0])                                                          
                tf.contrib.summary.scalar("valid/F1/0", metgen_valid.metrics["F1"][0])
                tf.contrib.summary.scalar("valid/F1/1", metgen_valid.metrics["F1"][1])
                tf.contrib.summary.scalar("valid/GM/0", metgen_valid.metrics["GM"][0])
                tf.contrib.summary.scalar("valid/MC/0", metgen_valid.metrics["MC"][0])                  
                tf.contrib.summary.scalar("valid/MK/0", metgen_valid.metrics["MK"][0])
                tf.contrib.summary.scalar("valid/ACC/mac", metgen_valid.metrics["ACC"]["macro"])                                                                            
                tf.contrib.summary.scalar("valid/PRC/mac", metgen_valid.metrics["PRC"]["macro"])
                tf.contrib.summary.scalar("valid/BAC/mac", metgen_valid.metrics["BAC"]["macro"])                                                            
                tf.contrib.summary.scalar("valid/F1/mac", metgen_valid.metrics["F1"]["macro"])
                tf.contrib.summary.scalar("valid/GM/mac", metgen_valid.metrics["GM"]["macro"])
                tf.contrib.summary.scalar("valid/MC/mac", metgen_valid.metrics["MC"]["macro"])                  
                tf.contrib.summary.scalar("valid/MK/mac", metgen_valid.metrics["MK"]["macro"])
                
                
                # Save model if metric is better than ever before, and test the model: Test flag.
                flag_save = False
                kwds_save = []
                for metric_tmp, dict_metric in metgen_valid.metrics.items():
                    for key, value in dict_metric.items():
                        # Nan to float
                        if np.isnan(value):
                            dict_metric[key] = -1. if (metric_tmp == "MC") or (metric_tmp == "MK") else 0.

                        # Conditioning and replace best values
                        if metgen_valid.metrics["SNS"][1] > 0.5:
                            if (metric_tmp=="SNS50") or (metric_tmp=="SNS70") or (metric_tmp=="SNS90") or \
                            (metric_tmp=="PRC50") or (metric_tmp=="PRC70") or (metric_tmp=="PRC90") or \
                            (metric_tmp=="REC50") or (metric_tmp=="REC70") or (metric_tmp=="REC90"):
                                continue
                            
                            else:
                                if dictbest.metrics[metric_tmp][key] < value:
                                    flag_save = True
                                    print("Iter: ", global_step.numpy())
                                    print("Best {} {} updated:".format(metric_tmp, key), dictbest.metrics[metric_tmp][key]," to ", value) #######
                                    dictbest.metrics[metric_tmp][key] = value
                                    kwds_save.append([metric_tmp, key]) 

                                    # For exp_phase="tuning", optuna
                                    if (metric_tmp == "GM") and (key == 0):
                                        best = value
                
                # Save model 2: ROC and PR
                if metgen_valid.metrics["SPC"][0] > 0.5:
                    metric_tmp = "SNS50"
                    key = 0
                    value = metgen_valid.metrics["SNS50"][0]

                    if dictbest.metrics[metric_tmp][key] < value:
                        flag_save = True
                        print("Iter: ", global_step.numpy())
                        print("Best {} {} updated:".format(metric_tmp, key), dictbest.metrics[metric_tmp][key]," to ", value) #######
                        dictbest.metrics[metric_tmp][key] = value
                        kwds_save.append([metric_tmp, key])

                        metric_tmp = "SPC50"
                        key = 0
                        value = metgen_valid.metrics["SPC50"][0]
                        dictbest.metrics[metric_tmp][key] = value
                        kwds_save.append([metric_tmp, key]) 
                        
                if metgen_valid.metrics["SPC"][0] > 0.7:
                    metric_tmp = "SNS70"
                    key = 0
                    value = metgen_valid.metrics["SNS70"][0]

                    if dictbest.metrics[metric_tmp][key] < value:
                        flag_save = True
                        print("Iter: ", global_step.numpy())
                        print("Best {} {} updated:".format(metric_tmp, key), dictbest.metrics[metric_tmp][key]," to ", value) #######
                        dictbest.metrics[metric_tmp][key] = value
                        kwds_save.append([metric_tmp, key]) 

                        metric_tmp = "SPC70"
                        key = 0
                        value = metgen_valid.metrics["SPC70"][0]
                        dictbest.metrics[metric_tmp][key] = value
                        kwds_save.append([metric_tmp, key]) 
                    
                if metgen_valid.metrics["SPC"][0] > 0.9:
                    metric_tmp = "SNS90"
                    key = 0
                    value = metgen_valid.metrics["SNS90"][0]

                    if dictbest.metrics[metric_tmp][key] < value:
                        flag_save = True
                        print("Iter: ", global_step.numpy())
                        print("Best {} {} updated:".format(metric_tmp, key), dictbest.metrics[metric_tmp][key]," to ", value) #######
                        dictbest.metrics[metric_tmp][key] = value
                        kwds_save.append([metric_tmp, key]) 

                        metric_tmp = "SPC90"
                        key = 0
                        value = metgen_valid.metrics["SPC90"][0]
                        dictbest.metrics[metric_tmp][key] = value
                        kwds_save.append([metric_tmp, key]) 
                    
                    
                if metgen_valid.metrics["SNS"][0] > 0.5:
                    metric_tmp = "PRC50"
                    key = 0
                    value = metgen_valid.metrics["PRC50"][0]

                    if dictbest.metrics[metric_tmp][key] < value:
                        flag_save = True
                        print("Iter: ", global_step.numpy())
                        print("Best {} {} updated:".format(metric_tmp, key), dictbest.metrics[metric_tmp][key]," to ", value) #######
                        dictbest.metrics[metric_tmp][key] = value
                        kwds_save.append([metric_tmp, key]) 

                        metric_tmp = "REC50"
                        key = 0
                        value = metgen_valid.metrics["REC50"][0]
                        dictbest.metrics[metric_tmp][key] = value
                        kwds_save.append([metric_tmp, key])
                    
                if metgen_valid.metrics["SNS"][0] > 0.7:
                    metric_tmp = "PRC70"
                    key = 0
                    value = metgen_valid.metrics["PRC70"][0]

                    if dictbest.metrics[metric_tmp][key] < value:
                        flag_save = True
                        print("Iter: ", global_step.numpy())
                        print("Best {} {} updated:".format(metric_tmp, key), dictbest.metrics[metric_tmp][key]," to ", value) #######
                        dictbest.metrics[metric_tmp][key] = value
                        kwds_save.append([metric_tmp, key]) 

                        metric_tmp = "REC70"
                        key = 0
                        value = metgen_valid.metrics["REC70"][0]
                        dictbest.metrics[metric_tmp][key] = value
                        kwds_save.append([metric_tmp, key])
                    
                if metgen_valid.metrics["SNS"][0] > 0.9:
                    metric_tmp = "PRC90"
                    key = 0
                    value = metgen_valid.metrics["PRC90"][0]

                    if dictbest.metrics[metric_tmp][key] < value:
                        flag_save = True
                        print("Iter: ", global_step.numpy())
                        print("Best {} {} updated:".format(metric_tmp, key), dictbest.metrics[metric_tmp][key]," to ", value) #######
                        dictbest.metrics[metric_tmp][key] = value
                        kwds_save.append([metric_tmp, key]) 

                        metric_tmp = "REC90"
                        key = 0
                        value = metgen_valid.metrics["REC90"][0]
                        dictbest.metrics[metric_tmp][key] = value
                        kwds_save.append([metric_tmp, key])
                    
                                        
                print("Keys of bests to be updated:\n", kwds_save)

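                # If any best value was replaced: save a ckpt, evaluate on the test sets
                # (only when exp_phase == "stat"), and update the best_test values.
                # NOTE(review): the format arguments below are (global_step, epoch), so the
                # "epoch"/"step" labels in the ckpt file name appear swapped -- confirm.
                # Save ckpt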
                if flag_save:
                    ckpt.save(dir_ckptpath + "/ckpt_epoch{}step{}_mnrvGM{:.7f}".format(
                        global_step.numpy(),
                        epoch,
                        metgen_valid.metrics["GM"][0]))
                    print("Saved ckpt: ", glob.glob(dir_ckptpath+"/*"))
                    print("")

                    # Test
                    ############################
                    if exp_phase == "stat":
                        ############################################
                        tmpls = []
                        #a = None
                        ############################################
                        for iter_irtest in ls_irtest:
                            # Lists for test
                            bs_list_test = []
                            loss_list_test = []
                            logit_list_test = []
                            y_testb_list = []
                            cm_list_test = []
                            ######################################################
                            #print("IRtest= ", iter_irtest)
                            #print(a == dic_x_test[iter_irtest][np.where(dic_y_test[iter_irtest]==0)[0]])
                            #a = dic_x_test[iter_irtest][np.where(dic_y_test[iter_irtest]==0)[0]]
                            ##print(np.where(dic_y_test[iter_irtest]==0))
                            ######################################################
                            # Test loop
                            print("Now testing...")
                            for iter_bv, (x_testb, y_testb) in enumerate(dataset_test[iter_irtest]):
                                if ((iter_bv+1)%10 == 0) or (iter_bv == 0):
                                    sys.stdout.write("\r iter: {}".format(iter_bv+1))
                                    sys.stdout.flush()
                                # Calc
                                loss_test, logits_test, _ = loss_fn(model, x_testb, y_testb, training=False)
                                cm_test, _ = cm_auc_fn(logits_test, y_testb, flag_auc=False)

                                # Append
                                bs_list_test.append(len(y_testb))
                                loss_list_test.append(loss_test)
                                logit_list_test.extend(logits_test)
                                y_testb_list.extend(y_testb)
                                cm_list_test.append(cm_test)
                            print("")
                            # Calc confusion matrix
                            cm_test_tot = np.sum(cm_list_test, axis=0)
                            metgen_test = MetricFunction(cm_test_tot)

                            #print("bs_list_test, loss_list_test: ", bs_list_test, loss_list_test) ###
                            loss_test = np.sum([i*j.numpy() for i,j in zip(bs_list_test, loss_list_test)]) / np.sum(bs_list_test)

                            #Verbose test 
                            print("Epoch {}: vloss={:.7f}, vACC={:.7f}, mnrvACC={:.7f}".format(
                                epoch+1,
                                loss_test, 
                                metgen_test.metrics["ACC"]["original"],
                                metgen_test.metrics["ACC"][0],
                                #rocauc_test['macro']
                            )
                                 )
                            print("Epoch {}: mnrvACC={:.7f}, mnrvSNS={:.7f}, mnrvF1={:.7f}, mnrvGM={:.7f}, mnrvMC={:.7f}, mnrvBAC={:.7f}".format(
                                epoch+1, 
                                metgen_test.metrics["ACC"][0],
                                metgen_test.metrics["SNS"][0],
                                metgen_test.metrics["F1"][0],
                                metgen_test.metrics["GM"][0],
                                metgen_test.metrics["MC"][0],
                                metgen_test.metrics["BAC"][0]
                                )
                            )
                            print("Total Confision Matirx for Test Data:")
                            print(cm_test_tot)
                            print("")

#                             # Tensorboard test
#                             tf.contrib.summary.scalar("test/IRtest{}/ACC/org".format(iter_irtest),
#                                                       metgen_test.metrics["ACC"]["original"])
#                             tf.contrib.summary.scalar("test/IRtest{}/ACC/0".format(iter_irtest),
#                                                       metgen_test.metrics["ACC"][0])
#                             tf.contrib.summary.scalar("test/IRtest{}/ACC/1".format(iter_irtest),
#                                                       metgen_test.metrics["ACC"][1])
#                             tf.contrib.summary.scalar("test/IRtest{}/SNS/0".format(iter_irtest),
#                                                       metgen_test.metrics["SNS"][0])
#                             tf.contrib.summary.scalar("test/IRtest{}/SNS/1".format(iter_irtest),
#                                                       metgen_test.metrics["SNS"][1])
#                             tf.contrib.summary.scalar("test/IRtest{}/SPC/0".format(iter_irtest),
#                                                       metgen_test.metrics["SPC"][0])
#                             tf.contrib.summary.scalar("test/IRtest{}/SPC/1".format(iter_irtest),
#                                                       metgen_test.metrics["SPC"][1])                                        
#                             tf.contrib.summary.scalar("test/IRtest{}/PRC/0".format(iter_irtest),
#                                                       metgen_test.metrics["PRC"][0])
#                             tf.contrib.summary.scalar("test/IRtest{}/PRC/1".format(iter_irtest),
#                                                       metgen_test.metrics["PRC"][1])
#                             tf.contrib.summary.scalar("test/IRtest{}/BAC/0".format(iter_irtest),
#                                                       metgen_test.metrics["BAC"][0])
#                             tf.contrib.summary.scalar("test/IRtest{}/BAC/1".format(iter_irtest),
#                                                       metgen_test.metrics["BAC"][1])                                                            
#                             tf.contrib.summary.scalar("test/IRtest{}/F1/0".format(iter_irtest),
#                                                       metgen_test.metrics["F1"][0])
#                             tf.contrib.summary.scalar("test/IRtest{}/F1/1".format(iter_irtest),
#                                                       metgen_test.metrics["F1"][1])
#                             tf.contrib.summary.scalar("test/IRtest{}/GM/0".format(iter_irtest),
#                                                       metgen_test.metrics["GM"][0])
#                             tf.contrib.summary.scalar("test/IRtest{}/GM/1".format(iter_irtest),
#                                                       metgen_test.metrics["GM"][1])
#                             tf.contrib.summary.scalar("test/IRtest{}/MC/0".format(iter_irtest),                                                 
#                                                       metgen_test.metrics["MC"][0])
#                             tf.contrib.summary.scalar("test/IRtest{}/MC/1".format(iter_irtest),
#                                                       metgen_test.metrics["MC"][1])                 
#                             tf.contrib.summary.scalar("test/IRtest{}/MK/0".format(iter_irtest),
#                                                       metgen_test.metrics["MK"][0])
#                             tf.contrib.summary.scalar("test/IRtest{}/MK/1".format(iter_irtest),
#                                                       metgen_test.metrics["MK"][1])
#                             tf.contrib.summary.scalar("test/IRtest{}/ACC/mac".format(iter_irtest),
#                                                       metgen_test.metrics["ACC"]["macro"])                    
#                             tf.contrib.summary.scalar("test/IRtest{}/SNS/mac".format(iter_irtest),
#                                                       metgen_test.metrics["SNS"]["macro"])                    
#                             tf.contrib.summary.scalar("test/IRtest{}/SPC/mac".format(iter_irtest),
#                                                       metgen_test.metrics["SPC"]["macro"])                                        
#                             tf.contrib.summary.scalar("test/IRtest{}/PRC/mic".format(iter_irtest),                                                   
#                                                       metgen_test.metrics["PRC"]["micro"])
#                             tf.contrib.summary.scalar("test/IRtest{}/PRC/mac".format(iter_irtest),
#                                                       metgen_test.metrics["PRC"]["macro"])
#                             tf.contrib.summary.scalar("test/IRtest{}/BAC/mic".format(iter_irtest),
#                                                       metgen_test.metrics["BAC"]["micro"])
#                             tf.contrib.summary.scalar("test/IRtest{}/BAC/mac".format(iter_irtest),                                              
#                                                       metgen_test.metrics["BAC"]["macro"])
#                             tf.contrib.summary.scalar("test/IRtest{}/F1/mic".format(iter_irtest),
#                                                       metgen_test.metrics["F1"]["micro"])
#                             tf.contrib.summary.scalar("test/IRtest{}/F1/mac".format(iter_irtest),                                                 
#                                                       metgen_test.metrics["F1"]["macro"])
#                             tf.contrib.summary.scalar("test/IRtest{}/GM/mic".format(iter_irtest),
#                                                       metgen_test.metrics["GM"]["micro"])
#                             tf.contrib.summary.scalar("test/IRtest{}/GM/mac".format(iter_irtest),
#                                                       metgen_test.metrics["GM"]["macro"])
#                             tf.contrib.summary.scalar("test/IRtest{}/MC/mic".format(iter_irtest),
#                                                       metgen_test.metrics["MC"]["micro"])
#                             tf.contrib.summary.scalar("test/IRtest{}/MC/mac".format(iter_irtest),
#                                                       metgen_test.metrics["MC"]["macro"])                   
#                             tf.contrib.summary.scalar("test/IRtest{}/MK/mic".format(iter_irtest),
#                                                       metgen_test.metrics["MK"]["micro"])
#                             tf.contrib.summary.scalar("test/IRtest{}/MK/mac".format(iter_irtest),
#                                                       metgen_test.metrics["MK"]["macro"])


                            # Nan to float
                            for key_mtr, dic_mtr in metgen_test.metrics.items():
                                for key_cls, val_tst in dic_mtr.items():
                                    if np.isnan(val_tst):
                                        dic_mtr[key_cls] = -1. if (key_mtr == "MC") or (key_mtr == "MK") else 0.

                            # Update best values
                            for _keys in kwds_save:
                                dictbest_test[iter_irtest].metrics[_keys[0]][_keys[1]] = metgen_test.metrics[_keys[0]][_keys[1]]
                            
                            
                            ##################################
                            tmpls.append((iter_irtest, metgen_test.metrics["SNS"][0]))
                            print("tmpls", tmpls)
                            ##################################
                            
                            
                                

    # Save best values to .db file ###
    for iter_irtest in ls_irtest:
        for key_metric, dict_bests in dictbest_test[iter_irtest].metrics.items():
            for key_cls, value in dict_bests.items():
                trial.set_user_attr("{}_{}_IRtest{}".format(key_metric, key_cls, iter_irtest), value)
                print("Final bests: {}_{}_IRtest{}, {}".format(key_metric, key_cls, iter_irtest, value))

    # Return objective value
    if exp_phase == "stat":
        return 1.
    elif exp_phase == "tuning":
        return 1. - best
    else:
        raise ValueError


# %%
if __name__ == '__main__':
    # Attach to an existing Optuna study (identified by `study_name` in
    # `storage`) and run `n_trials` additional trials of `objective`.
    # Alternative kept for reference: create a fresh study instead of loading one.
    #study = optuna.create_study(sampler=optuna.samplers.TPESampler())
    study = optuna.load_study(storage=storage, study_name=study_name)
    study.optimize(objective, n_trials=n_trials)

    # Summary of the best trial.
    print("IR ", IR)
    print("best_params")
    print(study.best_params)
    # `objective` returns `1. - best` in the "tuning" phase (and the constant 1.
    # in the "stat" phase), so `1 - best_value` converts the objective back.
    print("1 - best_value")
    print(1 - study.best_value)

    # Same parameters again, one per line, ordered by parameter name.
    print("\n --- sorted --- \n")
    sorted_best_params = sorted(study.best_params.items(), key=lambda item: item[0])
    for param_name, param_value in sorted_best_params:
        print(" : ".join((param_name, str(param_value))))
