# -*- coding: utf-8 -*-
# !/usr/bin/python

import sys
import time
import torch
sys.path.append("..")
import random
from baselines.utils import RANDOM
from baselines.basic_trainer import BasicTrainer


class EMRTrainer(BasicTrainer):
    """Continual-learning trainer using episodic memory replay (EMR).

    Tasks are trained one after another.  For each task, a sample of its
    training examples (selected by ``RANDOM``, sized by ``args.memory_size``)
    is stored in ``self.memory_data``.  While a later task is trained, every
    optimizer step also back-propagates the loss of a replay batch drawn from
    the stored memory of each earlier task, so old-task and new-task gradients
    are applied in the same update.
    """

    def __init__(self, args, model_save_path):
        super(EMRTrainer, self).__init__(args, model_save_path)

        # CL component
        self.past_task_id = -1  # id of the most recently started task
        self.observed_task_ids = []  # task ids in the order they were started
        self.memory_data = {}  # task id -> list of stored exemplars for replay

    def _replay_memory(self, current_task_id):
        """Back-propagate replay loss for every previously observed task.

        Gradients are only accumulated here; the caller performs the
        optimizer step.

        Args:
            current_task_id: id of the task being trained (must not be replayed).

        Returns:
            The summed report loss over all replayed batches of all past tasks.
        """
        total_replay_loss = 0.0
        # The last observed id is the task currently being trained; replay the rest.
        for past_task_id in self.observed_task_ids[:-1]:
            assert past_task_id != current_task_id
            memory = self.memory_data[past_task_id]
            replay_examples = random.sample(memory, min(len(memory), self.args.batch_size))
            random.shuffle(replay_examples)

            replay_report_loss, replay_example_num = 0.0, 0
            for _st in range(0, len(replay_examples), self.args.batch_size):
                _ed = min(_st + self.args.batch_size, len(replay_examples))
                replay_report_loss, replay_example_num, replay_loss = self.train_one_batch(
                    replay_examples[_st:_ed], replay_report_loss, replay_example_num)
                replay_loss.backward()

            # train_one_batch returns a running total (it is threaded through the
            # loop above), so add it once per task rather than once per batch.
            total_replay_loss += replay_report_loss
        return total_replay_loss

    def train(self):
        """Train sequentially on every task, replaying memory from earlier tasks.

        For each task ``i``: mark it as observed, store a memory sample of its
        training split, and train for up to ``args.epoch`` epochs.  Gradients
        are accumulated over ``args.accumulation_step`` batches, and past-task
        replay runs before each optimizer step.  The checkpoint with the best
        dev accuracy is kept (with patience-based early stopping), and the
        task's test split is evaluated at the end.

        Returns:
            Tuple ``(avg_acc_list, whole_acc_list, bwt_list, fwt_list)``.  These
            attributes are not set in this class; presumably they are populated
            by ``BasicTrainer`` / ``eval_task_stream``.
        """
        for i in range(self.args.task_num):
            best_result = {"acc": 0.0, "epoch": 0}
            examples = self.task_controller.task_list[i]["train"]

            n_epochs = self.args.epoch
            epoch_eval = self.args.epoch_eval

            patience = 0
            checkpoint_saved = False

            if i != self.past_task_id:
                self.observed_task_ids.append(i)
                self.past_task_id = i

            # Store this task's exemplars up front so they can be replayed later.
            self.memory_data[i] = list(RANDOM(examples=examples,
                                              memory_size=self.args.memory_size))

            for epoch in range(n_epochs):
                self.model.train()
                epoch_begin = time.time()
                # NOTE: shuffles the task's training list in place.
                random.shuffle(examples)
                report_loss, example_num = 0.0, 0
                self.optimizer.zero_grad()

                n_examples = len(examples)
                for cnt, st in enumerate(range(0, n_examples, self.args.batch_size)):
                    ed = min(st + self.args.batch_size, n_examples)
                    report_loss, example_num, loss = self.train_one_batch(examples[st:ed], report_loss, example_num)
                    loss.backward()

                    # Step once every accumulation_step batches, and on the final batch.
                    if (cnt + 1) % self.args.accumulation_step == 0 or ed == n_examples:
                        report_loss += self._replay_memory(i)

                        if self.args.clip_grad > 0.:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad)
                        self.optimizer.step()
                        self.optimizer.zero_grad()

                # NOTE: report_loss includes the replay loss, but example_num only
                # counts current-task examples.  Guard against an empty train split.
                avg_loss = report_loss / example_num if example_num else 0.0
                print("\nTask {}, Epoch Train {}, Loss {}, Time {}".format(i, epoch, avg_loss, time.time() - epoch_begin))

                if epoch < epoch_eval:
                    continue

                start_time = time.time()
                dev_acc, _, _, _ = self.epoch_acc(self.task_controller.task_list[i]["dev"])
                print('Evaluation: \tEpoch: %d\tTime: %.4f\tDev acc: %.4f\n' % (epoch, time.time() - start_time, dev_acc))

                if dev_acc >= best_result['acc']:
                    best_result['acc'], best_result['epoch'] = dev_acc, epoch
                    self.save(self.model, name="model.bin")
                    checkpoint_saved = True
                    patience = 0
                else:
                    patience += 1

                if patience > self.args.max_patience:
                    break

            # Restore the best dev checkpoint for this task.  If no dev evaluation
            # ran (epoch_eval >= epoch), nothing was saved for this task.  Loading
            # would then presumably restore an earlier task's file, or fail, so the
            # current weights are evaluated instead.
            if checkpoint_saved:
                self.load(self.model)

            start_time = time.time()
            test_acc, _, _, _ = self.epoch_acc(self.task_controller.task_list[i]["test"])
            print('Evaluation: \tTime: %.4f\tTest acc: %.4f\n' % (time.time() - start_time, test_acc))

            self.first_acc_list[i] = test_acc
            self.eval_task_stream(i, test_acc)

        return self.avg_acc_list, self.whole_acc_list, self.bwt_list, self.fwt_list
