import numpy as onp
from jax_utils import load_from_cache_or_compute, get_loss, get_argparse, append
import jax_utils as ju
import jax.random as random
import jax.numpy as np
import neural_tangents as nt
from data import get_datasets
from jax.example_libraries import optimizers
from jax import jit, grad, vmap
from pathlib import Path
import os
import math
import copy

def train_network_sgd(params, apply_fn, train_data, train_labels, threshold, arg, key):
    """Train ``params`` with minibatch momentum SGD until the epoch loss drops to ``threshold``.

    Each epoch visits the training set once in a fresh random order. The epoch
    loss is the example-weighted mean of the per-batch losses, each measured
    just before the corresponding update. Training stops after the first epoch
    whose loss is not greater than ``threshold``.

    Args:
        params: Initial parameters, passed to the optimizer's ``opt_init``.
        apply_fn: Callable ``apply_fn(params, x)`` producing the network output.
        train_data: Training inputs, indexed along axis 0.
        train_labels: Training targets, indexed along axis 0 like ``train_data``.
        threshold: Epoch-loss value at or below which training stops.
        arg: Options object; reads ``lr``, ``momentum``, ``loss`` and
            ``batch_sgd`` (a non-positive ``batch_sgd`` means full batch).
        key: Only printed in the final log line. Shuffling uses NumPy's global
            RNG (``onp.random.permutation``), not this key.

    Returns:
        The parameters extracted from the final optimizer state.

    Raises:
        ValueError: If ``train_data`` has no examples.
    """
    num_examples = train_data.shape[0]
    if num_examples == 0:
        raise ValueError("train_data is empty; cannot train")

    # Local import so the block does not depend on the file-level import list.
    from jax import value_and_grad

    opt_init, opt_update, get_params = optimizers.momentum(arg.lr, mass=arg.momentum)
    opt_update = jit(opt_update)
    loss_fn = get_loss(arg.loss)

    def loss(prm, x, y):
        return loss_fn(apply_fn(prm, x), y)

    # One jitted call yields both the loss and its gradient, so the forward
    # pass is not repeated for logging.
    loss_and_grad = jit(lambda state, x, y: value_and_grad(loss)(get_params(state), x, y))

    # Resolve the batch size once; non-positive means "use the whole set",
    # matching how the script treats ``arg.batch``.
    batch_size = arg.batch_sgd if arg.batch_sgd > 0 else num_examples
    num_batches = math.ceil(num_examples / batch_size)

    opt_state = opt_init(params)
    epoch = 0
    train_loss = threshold + 10  # sentinel that guarantees at least one epoch
    while train_loss > threshold:
        total_loss = 0
        perm = onp.random.permutation(num_examples)
        for batch in range(num_batches):
            # Slicing clamps at the end, so the last batch may be shorter.
            batch_ids = perm[batch_size * batch:batch_size * (batch + 1)]
            x_batch = train_data[batch_ids]
            y_batch = train_labels[batch_ids]
            curr_loss, grads = loss_and_grad(opt_state, x_batch, y_batch)
            if batch % 100 == 0:
                print(batch, curr_loss)
            total_loss += curr_loss * len(batch_ids)
            # The step index handed to the optimizer is the epoch counter, so
            # it stays constant within an epoch.
            opt_state = opt_update(epoch, grads, opt_state)
        epoch += 1
        train_loss = total_loss / num_examples
        print("Iteration", epoch, train_loss)

    # ``nan > threshold`` is False, so a diverged run leaves the loop as if it
    # had converged; make that visible instead of returning silently.
    if np.isnan(train_loss):
        print("Warning: training stopped because the loss became NaN")
    print(key, "Train loss", train_loss)

    return get_params(opt_state)

if __name__ == '__main__':
    # Experiment driver. For every (depth, width) pair it trains an ensemble of
    # networks with SGD in several variants (plain, output-shifted, linearized),
    # collects their predictions on the train / in-distribution / OOD loaders,
    # derives loss, accuracy, ensemble-variance and AUROC metrics, and writes
    # them to a tab-separated result file under ``arg.out_dir``.
    parser = get_argparse()
    arg = parser.parse_args()
    run_id = parser.print_arg(arg)
    dir_name = arg.out_dir + "/" + run_id
    print(dir_name)

    dataset = arg.dataset
    num_train_data = arg.num_train_data
    num_ensemble = arg.num_ensemble
    # NOTE(review): this normalised value (negative -> None) is never used; the
    # network builders below receive ``arg.bias`` directly. Confirm which one
    # is intended before changing either.
    bias = arg.bias if arg.bias >= 0 else None
    hidden_depths = [int(d) for d in arg.hidden_depths.split(",")]
    hidden_widths = [int(w) for w in arg.hidden_widths.split(",")]
    thresholds = [float(t) for t in arg.threshold.split(",")]
    num_test_data = min(num_train_data, 10000) if arg.num_test_data < 0 else arg.num_test_data
    train_dloader, ind_dloader, ood_dloaders, in_shape, out_shape = get_datasets(
        dataset, num_train_data, num_test_data=num_test_data, binary=arg.binary)

    train_data, train_labels = train_dloader.dataset
    # Axis 1 is moved to the end (presumably NCHW -> NHWC; confirm against the data module).
    train_data = np.array(train_data.permute([0, 2, 3, 1]).numpy())
    train_labels = np.array(train_labels.numpy())

    # Index 0 is the training set, index 1 the in-distribution test set, and
    # everything from index 2 on is an OOD set.
    eval_dloaders = [train_dloader, ind_dloader] + ood_dloaders

    def shifted_apply(_apply_fn, prm_sub, prm_add):
        # Builds f(prm, x) = _apply_fn(prm, x) - _apply_fn(prm_sub, x) + _apply_fn(prm_add, x),
        # dropping any term whose parameters are None.
        def shifted(prm, x):
            out = _apply_fn(prm, x)
            if prm_sub is not None:
                out = out - _apply_fn(prm_sub, x)
            if prm_add is not None:
                out = out + _apply_fn(prm_add, x)
            return out
        return shifted

    result = dict()
    for hidden_depth in hidden_depths:
        print("Depth", hidden_depth)
        result[hidden_depth] = dict()
        for hidden_width in hidden_widths:
            print("Width", hidden_width)
            print("Training")

            result[hidden_depth][hidden_width] = dict()
            curr_result = result[hidden_depth][hidden_width]

            if arg.net == 'mlp':
                init_fn, apply_fn, kernel_fn = ju.get_mlp(out_shape, hidden_width, hidden_depth, bias=arg.bias, activation=arg.activation)
            elif arg.net == 'conv':
                init_fn, apply_fn, kernel_fn = ju.get_miniminiconv(out_shape, hidden_width, hidden_depth, bias=arg.bias, activation=arg.activation)
            elif arg.net == 'wrn':
                init_fn, apply_fn, kernel_fn = ju.get_wrn(out_shape, hidden_width, hidden_depth, bias=arg.bias, activation=arg.activation, data=dataset)
            else:
                # Without this branch an unknown name would leave init_fn/apply_fn
                # undefined (or stale from a previous iteration).
                raise ValueError("Unsupported net type: {!r} (expected 'mlp', 'conv' or 'wrn')".format(arg.net))

            pred_vars = dict()
            preds = dict()

            # Reference initialisation shared by all ensemble members (used by the "inf" variants).
            key0 = random.PRNGKey(0)
            _, params0 = init_fn(key0, input_shape=train_data.shape)
            lin_apply_p0 = nt.linearize(apply_fn, params0)

            if arg.compute_gd_ntk_change:
                emp_kernel_fn = nt.batch(nt.empirical_kernel_fn(apply_fn),
                                         batch_size=arg.batch if arg.batch > 0 else arg.num_train_data)

            print("Precomputing Kernels")
            ntk_trained_norm = 0
            for specialist in range(arg.first_specialist, arg.first_specialist + num_ensemble):
                print(specialist)
                key = random.PRNGKey(specialist)
                _, params = init_fn(key, input_shape=train_data.shape)
                rand_key = random.PRNGKey(1000 + specialist)
                _, rand_params = init_fn(rand_key, input_shape=train_data.shape)

                if arg.compute_gd_ntk_change:
                    # Empirical NTK on the training set at initialisation.
                    emp_ntk_trtr = emp_kernel_fn(train_data, train_data, "ntk", params)

                lin_apply = nt.linearize(apply_fn, params)

                # (name suffix, function to train/evaluate, initial parameters).
                # The two lambdas read lin_apply / params / rand_params from the
                # enclosing scope; they are only called within this iteration,
                # before those names are rebound.
                apply_fns = [
                    ("", shifted_apply(apply_fn, None, None), copy.deepcopy(params)),
                    ("decorr", shifted_apply(apply_fn, copy.deepcopy(params), copy.deepcopy(rand_params)), copy.deepcopy(params)),
                    ("zero", shifted_apply(apply_fn, copy.deepcopy(params), None), copy.deepcopy(params)),
                    ("inf", shifted_apply(apply_fn, copy.deepcopy(params0), copy.deepcopy(params)), copy.deepcopy(params0)),
                    ("_lin", lin_apply, copy.deepcopy(params)),
                    ("decorr_lin", (lambda prm, x: lin_apply(prm, x) - lin_apply(params, x) + apply_fn(rand_params, x)), copy.deepcopy(params)),
                    ("zero_lin", shifted_apply(lin_apply, copy.deepcopy(params), None), copy.deepcopy(params)),
                    ("inf_lin", (lambda prm, x: lin_apply_p0(prm, x) - lin_apply_p0(params0, x) + apply_fn(params, x)), copy.deepcopy(params0)),
                ]

                if arg.no_lin:
                    # NOTE(review): the first five entries still include the plain
                    # linearized "_lin" variant; confirm that is intended.
                    apply_fns = apply_fns[:5]

                for strkey, curr_apply_fn, prm in apply_fns:
                    pred_name = "gd" + strkey
                    if pred_name not in preds:
                        preds[pred_name] = [[] for _ in eval_dloaders]
                    print("Training D{}W{} specialist {} {}".format(hidden_depth, hidden_width, specialist, pred_name))

                    # Thresholds are applied in sequence, each stage starting
                    # from the parameters returned by the previous one.
                    gd_params = prm
                    for th in thresholds:
                        th_str = "/th_{}".format(th)
                        gd_params = load_from_cache_or_compute(dir_name + "/d{}_w{}".format(hidden_depth, hidden_width) + th_str, "gd{}_{}".format(strkey, specialist),
                                                               train_network_sgd, gd_params, curr_apply_fn, train_data, train_labels, th, arg, key)

                        if strkey == "decorr" and arg.compute_gd_ntk_change:
                            # Relative change of the empirical NTK after training,
                            # averaged over the ensemble.
                            # NOTE(review): with several thresholds this adds one
                            # term per threshold; confirm that is intended.
                            trained_emp_ntk_trtr = emp_kernel_fn(train_data, train_data, "ntk", gd_params)
                            ntk_trained_norm += np.linalg.norm(emp_ntk_trtr - trained_emp_ntk_trtr) / np.linalg.norm(emp_ntk_trtr) / arg.num_ensemble

                    for i, dloader in enumerate(eval_dloaders):
                        eval_data = np.array(dloader.dataset[0].permute([0, 2, 3, 1]).numpy())

                        # Non-positive batch size means "evaluate in one batch".
                        eval_batch = arg.batch_sgd if arg.batch_sgd > 0 else eval_data.shape[0]
                        curr_pred = []
                        for batch in range(math.ceil(eval_data.shape[0] / eval_batch)):
                            x_batch = eval_data[eval_batch * batch:eval_batch * (batch + 1)]
                            curr_pred.append(curr_apply_fn(gd_params, x_batch))
                        preds[pred_name][i].append(np.concatenate(curr_pred, axis=0))
                        print(preds[pred_name][i][-1].shape)
            append(curr_result, "ntk_trained_norm", ntk_trained_norm)

            # Stack ensemble members along a new axis 1, then take the variance
            # over members summed over the trailing axis.
            for name in preds:
                preds[name] = [np.stack(member_preds, axis=1) for member_preds in preds[name]]
                pred_vars[name] = [np.sum(np.var(stacked, axis=1), axis=1) for stacked in preds[name]]

            # Loss and accuracy of the ensemble mean and of the individual
            # members, on the training and in-distribution sets.
            loss_fn = get_loss(arg.loss)
            for name in preds:
                for idx, dloader in enumerate([train_dloader, ind_dloader]):
                    eval_labels = np.array(dloader.dataset[1].numpy())
                    append(curr_result, 'loss_ens_' + name, loss_fn(np.mean(preds[name][idx], axis=1), eval_labels))
                    append(curr_result, 'loss_spec_' + name, onp.mean([loss_fn(preds[name][idx][:, i], eval_labels) for i in range(arg.num_ensemble)]))
                    append(curr_result, 'acc_ens_' + name, ju.acc_fn(np.mean(preds[name][idx], axis=1), eval_labels))
                    append(curr_result, 'acc_spec_' + name, onp.mean([ju.acc_fn(preds[name][idx][:, i], eval_labels) for i in range(arg.num_ensemble)]))

            # Compare every variant with its linearized counterpart, where one exists.
            for name in preds:
                name_lin = name + "_lin"
                if name_lin not in preds:
                    continue
                num_sets = len(preds[name_lin])
                for i in range(num_sets):
                    append(curr_result, "abs_wrt_gd_pred_" + name, np.mean(np.abs(preds[name_lin][i] - preds[name][i])))
                for i in range(num_sets):
                    append(curr_result, "mse_wrt_gd_pred_vs_gd_predvar_" + name, np.mean(np.mean((preds[name_lin][i] - preds[name][i])**2, axis=(1, 2)) / pred_vars[name][i]))
                for i in range(num_sets):
                    append(curr_result, "mse_mean_wrt_mean_gd_pred_" + name, np.mean((np.mean(preds[name_lin][i], axis=1) - np.mean(preds[name][i], axis=1))**2))
                for i in range(num_sets):
                    append(curr_result, "abs_var_wrt_var_gd_pred_" + name, np.mean(np.abs(pred_vars[name_lin][i] - pred_vars[name][i])))
                for i in range(num_sets):
                    append(curr_result, "mse_var_wrt_var_gd_pred_vs_var_gd_pred_" + name, np.mean((pred_vars[name_lin][i] - pred_vars[name][i])**2 / pred_vars[name][i]**2))

            # Predvar and auroc
            for name in pred_vars:
                for v in pred_vars[name]:
                    append(curr_result, "predvar_" + name, np.mean(v))
                # AUROC between the in-distribution variances (index 1) and each OOD set.
                for v in pred_vars[name][2:]:
                    auroc = ju.compute_auroc(onp.array(pred_vars[name][1]), onp.array(v))
                    append(curr_result, "auroc_" + name, auroc)

            # Linearization error relative to the "zero_lin" reference, as a
            # geometric mean (exp of the mean log-ratio). These need the
            # linearized variants, which are absent under --no_lin; skip them
            # then instead of failing with KeyError after all training is done.
            required = ("gddecorr", "gddecorr_lin", "gdzero_lin")
            if all(name in pred_vars for name in required):
                num_sets = len(preds["gddecorr"])
                for i in range(num_sets):
                    var_ratio = np.abs(pred_vars["gddecorr"][i] - pred_vars["gddecorr_lin"][i]) / pred_vars["gdzero_lin"][i]
                    append(curr_result, "var_linearization_error_vs_finitewidth_error",
                           np.exp(np.mean(np.log(var_ratio))))
                for i in range(num_sets):
                    mean_diff = np.mean(preds["gddecorr"][i], axis=1) - np.mean(preds["gddecorr_lin"][i], axis=1)
                    mean_ratio = mean_diff**2 / np.mean(preds["gdzero_lin"][i], axis=1)**2
                    # Uses log inside the mean so this is a geometric mean like the
                    # variance metric above; exp(mean(exp(.))) is not one.
                    append(curr_result, "mean_linearization_error_vs_finitewidth_error",
                           np.exp(np.mean(np.log(mean_ratio))))
            else:
                print("Skipping linearization-vs-finitewidth metrics; missing:",
                      [name for name in required if name not in pred_vars])

            Path(dir_name).mkdir(parents=True, exist_ok=True)
            result_name = "sgd_d{}_w{}_ens{}-{}_th{}_result.txt".format(hidden_depth, hidden_width, arg.first_specialist, arg.num_ensemble, arg.threshold.split(",")[-1])
            with open(os.path.join(dir_name, result_name), 'w') as out_file:
                for metric in curr_result:
                    line = "{}\t{}\t{}\t{}".format(hidden_depth, hidden_width, metric, "\t".join([str(v) for v in curr_result[metric]]))
                    print(line)
                    out_file.write(line + "\n")

    # Final summary of every metric for every (depth, width) pair.
    for k1 in result:
        for k2 in result[k1]:
            for k3 in result[k1][k2]:
                print("{}\t{}\t{}\t{}".format(k1, k2, k3, "\t".join([str(v) for v in result[k1][k2][k3]])))

