from neuralalgo.data_generator import QuaOptDataset
from neuralalgo.common.cmd_args import cmd_args
from neuralalgo.common.consts import DEVICE
from neuralalgo.model_algos import GD, NAG, MLPRNN
from neuralalgo.trainer import AlgoTrainer
import random
import numpy as np
import torch
import torch.optim as optim
import pickle


def append_results(results, args, train_loss, test_loss, s):
    """Append one run's record to a column-oriented results dict.

    Each key of ``results`` maps to a list; this call appends exactly one
    value to every column, creating the column if it does not exist yet.

    Args:
        results: dict of column name -> list. Mutated in place.
        args: object exposing the run settings read below (``algo_type``,
            ``k``, ``mu``, ``L``, ``num_train``, ``num_test``, ``num_epochs``,
            ``optimizer``, ``learning_rate``, ``weight_decay``, ``seed_train``,
            ``model_dump``).
        train_loss: value stored in the ``'train_loss'`` column.
        test_loss: value stored in the ``'test_loss'`` column.
        s: value stored in the ``'s'`` column.

    Returns:
        The same ``results`` object that was passed in.
    """
    # (column name, deferred value) pairs, listed in the order the columns
    # are written. Values are wrapped in lambdas so each one is only
    # evaluated at the moment its column is appended to.
    columns = (
        ('algo', lambda: args.algo_type),
        ('s', lambda: s),
        ('k', lambda: args.k),
        ('train_loss', lambda: train_loss),
        ('test_loss', lambda: test_loss),
        # generalization gap: test loss minus train loss
        ('gap', lambda: test_loss-train_loss),
        ('mu', lambda: args.mu),
        ('L', lambda: args.L),
        ('num_train', lambda: args.num_train),
        ('num_test', lambda: args.num_test),
        ('epoch', lambda: args.num_epochs),
        ('optimizer', lambda: args.optimizer),
        ('lr', lambda: args.learning_rate),
        ('dc', lambda: args.weight_decay),
        ('seed', lambda: args.seed_train),
        ('dump', lambda: args.model_dump),
    )

    for column_name, value_of in columns:
        results.setdefault(column_name, []).append(value_of())

    return results


def save_to_pkl(args, train_err, gen_gap, s, filename='results'):
    """Log this run's result and update the per-configuration best result.

    Two files are touched:

    * ``<filename>.pkl`` -- opened in append mode; one pickled record (as
      built by ``append_results``) is appended per call.
    * ``<filename>_best.pkl`` -- a single pickled nested dict indexed as
      ``[num_train][algo_type][k][seed_train]``. The entry for the current
      configuration is written if none exists yet, or replaced if this run's
      training loss is strictly lower than the stored one.

    Args:
        args: run settings; ``num_train``, ``algo_type``, ``k`` and
            ``seed_train`` are used as keys here, the rest is consumed by
            ``append_results``.
        train_err: training loss of this run.
        gen_gap: value added to ``train_err`` to obtain the test loss.
        s: value recorded in the ``'s'`` column and printed.
        filename: path prefix (without extension) for both pickle files.
    """
    train_loss = train_err
    test_loss = train_err + gen_gap

    print('train:%.3f,gap:%.3f,s:%.3f' % (train_loss, test_loss-train_loss, s))

    this_result = append_results({}, args, train_loss, test_loss, s)
    save_result_filename = filename + '.pkl'
    # Append mode: the log accumulates one independent pickle per call.
    with open(save_result_filename, 'ab') as handle:
        pickle.dump(this_result, handle, protocol=pickle.HIGHEST_PROTOCOL)

    best_result_filename = filename + '_best.pkl'
    # NOTE: pickle.load executes arbitrary code from the file; this is only
    # safe because the file is expected to have been written by this script.
    try:
        with open(best_result_filename, 'rb') as handle:
            best_results = pickle.load(handle)
    except FileNotFoundError:
        # First run with this filename: start from an empty table instead of
        # crashing after the run has already been logged above.
        best_results = {}

    # Walk (creating as needed) down to the per-seed table for this config.
    per_seed = (best_results
                .setdefault(args.num_train, {})
                .setdefault(args.algo_type, {})
                .setdefault(args.k, {}))

    is_first = args.seed_train not in per_seed
    # Stored records are column dicts with one-element lists, hence the [0].
    if is_first or train_loss < per_seed[args.seed_train]['train_loss'][0]:
        per_seed[args.seed_train] = this_result
        with open(best_result_filename, 'wb') as handle:
            pickle.dump(best_results, handle, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == '__main__':
    # Script entry point: build the dataset and the algorithm selected on the
    # command line, then (in the 'train' phase) train it and record results.

    # Seed the Python, NumPy and PyTorch RNGs before any data/model creation.
    random.seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)

    # initialize train and test data
    db = QuaOptDataset(mu=cmd_args.mu,
                       L=cmd_args.L,
                       d=cmd_args.d,
                       population=cmd_args.num_test,
                       train=cmd_args.num_train)

    # resample training data
    db.resample_train(cmd_args.seed_train)

    # initialize the network
    if cmd_args.algo_type == 'gd':
        algo = GD(k=cmd_args.k, init_s=cmd_args.init_s).to(DEVICE)
    elif cmd_args.algo_type == 'nag':
        algo = NAG(k=cmd_args.k, mu=cmd_args.mu, init_s=cmd_args.init_s).to(DEVICE)
    elif cmd_args.algo_type == 'mlp_rnn':
        algo = MLPRNN(k=cmd_args.k, hidden_dim=cmd_args.mlp_hidden_dims, d=cmd_args.d, activation=cmd_args.activation).to(DEVICE)
    else:
        # Explicit error rather than `assert`: asserts are removed under
        # `python -O`, which would silently fall through to MLPRNN.
        raise ValueError('unknown algo_type: %r' % (cmd_args.algo_type,))

    # training
    if cmd_args.phase == 'train':

        if cmd_args.optimizer == 'sgd':
            optimizer = optim.SGD(algo.parameters(),
                                  lr=cmd_args.learning_rate,
                                  weight_decay=cmd_args.weight_decay)
        elif cmd_args.optimizer == 'adam':
            optimizer = optim.Adam(algo.parameters(),
                                   lr=cmd_args.learning_rate,
                                   weight_decay=cmd_args.weight_decay)
        else:
            # Same reasoning as above: do not rely on `assert` for CLI input.
            raise ValueError('unknown optimizer: %r' % (cmd_args.optimizer,))

        trainer = AlgoTrainer(cmd_args, db, algo, optimizer, dump=False)
        # Re-seed all RNGs with seed2 immediately before training.
        # NOTE(review): presumably so training randomness can be varied
        # independently of data/model initialization -- confirm.
        random.seed(cmd_args.seed2)
        np.random.seed(cmd_args.seed2)
        torch.manual_seed(cmd_args.seed2)

        best_loss, gap, best_s = trainer.train()
        save_to_pkl(cmd_args, best_loss, gap, best_s, cmd_args.filename)
