import torch
from torch import nn
from torch.autograd import Variable
from torch.optim import lr_scheduler


class Ensembler:
    """Pull a main model's logits toward the average of two models' distributions.

    Both models' token distributions are projected into a common space by
    multiplying them with a per-model "probability transfer matrix". The main
    model's logits are then optimized so that their projection moves toward
    the mean of the two projections.

    Attributes are stored as given; the main transfer matrix must live on
    ``device`` because it is multiplied with logits that are moved there.
    """

    def __init__(self, main_model_probability_transfer_matrix,
                 assist_model_probability_transfer_matrix,
                 device="cuda:0"):
        """Store the two transfer matrices and the working device.

        Args:
            main_model_probability_transfer_matrix: 2-D tensor right-multiplied
                onto the main model's probabilities (``torch.mm``).
            assist_model_probability_transfer_matrix: 2-D tensor
                right-multiplied onto the assist model's probabilities. Its
                column count has to match the main matrix so the two
                projections can be averaged.
            device: Device on which the optimization runs and on which the
                result is returned.
        """
        self.main_model_probability_transfer_matrix = main_model_probability_transfer_matrix
        self.assist_model_probability_transfer_matrix = assist_model_probability_transfer_matrix
        self.device = device

    @staticmethod
    def _project(logits, transfer_matrix):
        """Softmax ``logits`` over the last dim (as float32) and apply ``transfer_matrix``."""
        probs = nn.functional.softmax(logits, dim=-1).float()
        return torch.mm(probs, transfer_matrix)

    def ensemble(self, main_model_generate_ids_logits, assist_model_generate_ids_logits, learning_epochs_nums=5,
                 learning_rate=0):
        """Return main-model logits nudged toward the two-model average.

        Args:
            main_model_generate_ids_logits: 2-D logits of the main model
                (``torch.mm`` requires a matrix). The input is never modified;
                a copy is optimized.
            assist_model_generate_ids_logits: 2-D logits of the assist model
                with the same number of rows.
            learning_epochs_nums: Number of optimizer steps to take.
            learning_rate: AdamW learning rate. If its magnitude is not above
                ``1e-6`` no optimization happens and the main logits are
                returned unchanged (moved to ``self.device`` and detached).

        Returns:
            A detached tensor on ``self.device`` with the shape of the main
            logits.

        The caller's autograd mode (enabled/disabled) is left exactly as it
        was, including when an exception is raised during optimization.
        """
        # Written as a negated ">" so a NaN learning rate keeps taking the
        # pass-through path, as before.
        if not abs(learning_rate) > 1e-6:
            return main_model_generate_ids_logits.to(self.device).detach()

        # Work on a private copy that already lives on the target device, so
        # every matmul against the main transfer matrix uses one device and
        # the caller's tensor is never updated in place.
        trainable_logits = main_model_generate_ids_logits.to(self.device).detach().clone()

        # The target is a constant: build it without recording a graph.
        with torch.no_grad():
            main_projection = self._project(
                trainable_logits, self.main_model_probability_transfer_matrix)
            # The assist projection is computed wherever the assist tensors
            # live and only then moved to the working device.
            assist_projection = self._project(
                assist_model_generate_ids_logits,
                self.assist_model_probability_transfer_matrix).to(self.device)
            average_probs = (main_projection + assist_projection) / 2

        # enable_grad() is scoped: it restores the previous global autograd
        # mode on exit instead of forcing it off for the rest of the process.
        with torch.enable_grad():
            trainable_logits.requires_grad_(True)

            # NOTE(review): the default reduction ('mean') is kept on purpose
            # so results match earlier runs; PyTorch suggests 'batchmean' for
            # a mathematically exact KL value.
            criterion = nn.KLDivLoss()
            optimizer = torch.optim.AdamW(params=[trainable_logits],
                                          lr=learning_rate,
                                          betas=(0.9, 0.999))
            # Cosine decay from learning_rate toward learning_rate / 4,
            # stepped once per optimization step.
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=learning_rate / 4)

            for _ in range(learning_epochs_nums):
                projection = self._project(
                    trainable_logits, self.main_model_probability_transfer_matrix)
                # KLDivLoss expects log-probabilities as its first argument.
                loss = criterion(torch.log(projection), average_probs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

        return trainable_logits.detach()


