# -*- coding: utf-8 -*-
# !/usr/bin/python

import sys
import time
import torch
sys.path.append("..")
import random
from baselines.basic_trainer import BasicTrainer
from baselines.utils import RANDOM, store_grad, project2cone2, overwrite_grad


class GEMTrainer(BasicTrainer):
    """Continual-learning trainer that applies a GEM-style gradient projection.

    Tasks are trained one after another. After each task a sample of its
    training examples is kept in ``memory_data``. While training a later task,
    every optimizer step first computes the gradient of each earlier task on a
    replayed memory sample; if the current gradient has a negative dot product
    with any of them it is projected with ``project2cone2`` before the step.

    Parameters whose name contains ``"plm_model"`` are left out of the
    gradient bookkeeping (``grad_dims`` / ``grads``).
    """

    def __init__(self, args, model_save_path):
        """Set up the base trainer and the per-task gradient buffer.

        Args:
            args: Run configuration. This class reads ``task_num``, ``cuda``,
                ``epoch``, ``epoch_eval``, ``batch_size``,
                ``accumulation_step``, ``gem_margin``, ``clip_grad``,
                ``max_patience`` and ``memory_size`` from it.
            model_save_path: Forwarded unchanged to ``BasicTrainer.__init__``.
        """
        super(GEMTrainer, self).__init__(args, model_save_path)

        # CL component
        self.past_task_id = -1
        self.observed_task_ids = []
        self.memory_data = {}  # task id -> examples sampled from that task's training set

        # Number of elements of every tracked parameter, in named_parameters() order.
        self.grad_dims = [
            param.data.numel()
            for name, param in self.model.named_parameters()
            if "plm_model" not in name
        ]

        # One column of flattened gradients per task. Zero-initialised so that
        # no column ever exposes uninitialised memory.
        self.grads = torch.zeros(sum(self.grad_dims), args.task_num)
        if args.cuda:
            self.grads = self.grads.cuda()

    def _project_current_gradient(self, task_id, report_loss):
        """Replace the model's gradient with the GEM-constrained gradient.

        Must be called right after the backward pass(es) of the current task,
        while the parameters still hold the current-task gradient.

        Args:
            task_id: Index of the task currently being trained.
            report_loss: Running loss total of the epoch; the loss reported by
                ``train_one_batch`` for each replay batch is added to it.

        Returns:
            The updated ``report_loss``.
        """
        # copy gradient
        # NOTE(review): store_grad is handed the bound method while
        # overwrite_grad (below) is handed the iterator it returns; both call
        # shapes are kept as found -- confirm against baselines.utils.
        store_grad(self.model.named_parameters,
                   self.grads,
                   self.grad_dims, task_id)

        past_task_ids = self.observed_task_ids[:-1]
        for past_task_id in past_task_ids:
            assert past_task_id != task_id

            self.optimizer.zero_grad()
            memory = self.memory_data[past_task_id]
            replay_examples = random.sample(memory, min(len(memory), self.args.batch_size))
            # Redundant after sample(), but kept so the RNG stream is unchanged.
            random.shuffle(replay_examples)

            # The sample never exceeds batch_size, so one batch covers it.
            if replay_examples:
                replay_report_loss, _, replay_loss = self.train_one_batch(replay_examples, 0.0, 0)
                replay_loss.backward()
                report_loss += replay_report_loss

            store_grad(self.model.named_parameters,
                       self.grads,
                       self.grad_dims,
                       past_task_id)

        # Discard the gradient left behind by the last replay pass.
        self.optimizer.zero_grad()

        # Build the index on the same device as the buffer so CPU runs work too.
        indx = torch.tensor(past_task_ids, dtype=torch.long, device=self.grads.device)
        past_grads = self.grads.index_select(1, indx)
        dotp = torch.mm(self.grads[:, task_id].unsqueeze(0), past_grads)

        if (dotp < 0).sum() != 0:
            # unsqueeze() returns a view, so an in-place projection lands in self.grads.
            project2cone2(self.grads[:, task_id].unsqueeze(1),
                          past_grads,
                          float(self.args.gem_margin))

        # copy gradients back
        # This must happen even when nothing was projected: the zero_grad()
        # above wiped the current-task gradient, and without restoring it the
        # optimizer would step on empty gradients.
        # NOTE(review): optimizer.zero_grad() defaults to set_to_none=True on
        # recent PyTorch -- confirm overwrite_grad copes with param.grad being None.
        overwrite_grad(self.model.named_parameters(),
                       self.grads[:, task_id],
                       self.grad_dims)
        return report_loss

    def train(self):
        """Train on every task in order and evaluate after each one.

        For each task: run up to ``args.epoch`` epochs with gradient
        accumulation, evaluate on the dev split from epoch ``args.epoch_eval``
        onwards, checkpoint whenever dev accuracy does not decrease, and stop
        once it has decreased more than ``args.max_patience`` times in a row.
        Afterwards the checkpoint is reloaded, the test split is scored, and a
        sample of the training examples is stored for replay.

        Returns:
            The tuple ``(avg_acc_list, whole_acc_list, bwt_list, fwt_list)``
            read from ``self``; these attributes are maintained outside this
            class.
        """
        for i in range(self.args.task_num):
            best_result = {"acc": 0.0, "epoch": 0}
            examples = self.task_controller.task_list[i]["train"]

            n_epochs = self.args.epoch
            epoch_eval = self.args.epoch_eval

            patience = 0

            if i != self.past_task_id:
                self.observed_task_ids.append(i)
                self.past_task_id = i

            for epoch in range(n_epochs):
                self.model.train()
                epoch_begin = time.time()
                random.shuffle(examples)
                report_loss, example_num = 0.0, 0
                self.optimizer.zero_grad()

                for cnt, st in enumerate(range(0, len(examples), self.args.batch_size)):
                    # training on the batch of current task
                    ed = min(st + self.args.batch_size, len(examples))

                    report_loss, example_num, loss = self.train_one_batch(examples[st:ed], report_loss, example_num)
                    loss.backward()

                    # Keep accumulating until a full accumulation window or the
                    # end of the epoch is reached.
                    is_last_batch = ed == len(examples)
                    if (cnt + 1) % self.args.accumulation_step != 0 and not is_last_batch:
                        continue

                    # gem: only meaningful once at least one earlier task exists
                    if len(self.observed_task_ids) > 1:
                        report_loss = self._project_current_gradient(i, report_loss)

                    if self.args.clip_grad > 0.:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad)
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                print("\nTask {}, Epoch Train {}, Loss {}, Time {}".format(i, epoch, report_loss, time.time() - epoch_begin))

                if epoch < epoch_eval:
                    continue

                start_time = time.time()
                dev_acc, _, _, _ = self.epoch_acc(self.task_controller.task_list[i]["dev"])
                print('Evaluation: \tEpoch: %d\tTime: %.4f\tDev acc: %.4f\n' % (epoch, time.time() - start_time, dev_acc))

                if dev_acc >= best_result['acc']:
                    best_result['acc'], best_result['epoch'] = dev_acc, epoch
                    self.save(self.model, name="model.bin")
                    patience = 0
                else:
                    patience += 1

                if patience > self.args.max_patience:
                    break

            # NOTE(review): a checkpoint is only written in the dev-evaluation
            # branch above; if no epoch reaches epoch_eval, nothing is saved
            # for this task before this load -- confirm that is acceptable.
            self.load(self.model)
            start_time = time.time()
            test_acc, _, _, _ = self.epoch_acc(self.task_controller.task_list[i]["test"])
            print('Evaluation: \tTime: %.4f\tTest acc: %.4f\n' % (time.time() - start_time, test_acc))

            self.first_acc_list[i] = test_acc
            self.eval_task_stream(i, test_acc)

            # Keep a sample of this task's training data for later replay.
            self.memory_data[i] = list(RANDOM(examples=examples,
                                              memory_size=self.args.memory_size))

        return self.avg_acc_list, self.whole_acc_list, self.bwt_list, self.fwt_list
