import json
import logging
import math
import os
import pdb
import queue

import torch
from torch import nn
from torch.autograd import Variable
from torch.optim import lr_scheduler
from transformers import LogitsProcessor


class BasedOnProbabilityTransferLogits_Loacal_0514_Record_Processor(LogitsProcessor):
    """Logits processor that ensembles a main model with assist models in a shared "relative" space.

    On every generation step it:

    1. reads one logits tensor per assist model from ``assist_model_score_queue_list``;
    2. falls back to the main model alone when an assist model did not respond in time,
       ``|learning_rate| <= 1e-6``, the main model's arg-max is ``forced_eos_token_id``,
       or an early-stop string was just generated (the EOS logit is then set to ``+inf``);
    3. otherwise projects every model's softmax distribution through its probability
       transfer matrix (``torch.mm``), averages the projections, and runs
       ``learning_epochs_nums`` AdamW steps on a copy of the main model's probabilities so
       that its projection moves towards that average (KL-divergence loss);
    4. appends the top-10 values/indices of the intermediate tensors as one JSON line to a
       log file in ``result_save_dir``;
    5. puts the arg-max token id of the logits it returns on
       ``ensemble_model_output_ids_queue``.

    Only row 0 of the batch is logged and checked for early-stop strings, so the processor
    presumably expects batch size 1 — TODO confirm with callers.
    """

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None,
                 model_weights=None, queue_timeout=5):
        """Store configuration; no computation happens here.

        Args:
            learning_rate: AdamW learning rate; ``|learning_rate| <= 1e-6`` disables ensembling.
            anchor_point_count: stored only; not read by this class.
            learning_epochs_nums: number of optimisation steps per generation step.
            ensemble_model_output_ids_queue: receives the chosen next-token ids every step.
            assist_model_score_queue_list: one queue per assist model, each yielding that
                model's logits for the current step.
            main_model_probability_transfer_matrix_list: element 0 is the main model's transfer
                matrix, right-multiplied with the main model's probabilities via ``torch.mm``.
            assist_model_probability_transfer_matrix_list: one transfer matrix per assist model,
                paired positionally with ``assist_model_score_queue_list``.
            result_save_dir: directory that receives the JSON-lines log file.
            main_model_tokenizer: used to tokenise ``early_stop_string_list``.
            assist_model_tokenizer: stored only; not read by this class.
            device: device the ensemble-path logits are moved to before being returned.
            device_compute: device on which projections and optimisation run.
            forced_eos_token_id: EOS id that triggers main-model-only mode and is forced on
                early stop.
            early_stop_string_list: optional strings that force EOS once they were generated.
            model_weights: optional averaging weights (main model first, then assist models);
                ``None`` means uniform weights.
            queue_timeout: seconds to wait for each assist model's logits.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list
        self.model_weights = None if model_weights is None else list(model_weights)
        self.queue_timeout = queue_timeout

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Process one generation step.

        Args:
            input_ids: token ids generated so far; row 0 is checked for early-stop strings.
            scores: main-model logits for the next token. Mutated in place (EOS logit set to
                ``+inf``) when an early-stop string was just generated.

        Returns:
            ``scores`` itself in main-model-only mode; otherwise a detached float32 copy of the
            main-model logits moved to ``self.device``.

        Raises:
            ValueError: if ``model_weights`` was given with the wrong length.
        """
        # NOTE(review): the file name hard-codes "anchor_point_count_all" and
        # "learning_epochs_nums_5" regardless of the actual settings; kept verbatim so existing
        # log consumers keep finding the file.
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        json_object = {}
        self._record_topk(json_object, 'main_values', 'main_indices', scores)

        # Every assist queue is drained each step, even after one times out, so the
        # remaining queues stay in step with generation.
        assist_model_generate_ids_logits_list, assist_missing = self._collect_assist_logits(json_object)

        main_model_only_flag = assist_missing
        if math.fabs(self.learning_rate) <= 1e-6:
            main_model_only_flag = True
        if torch.argmax(scores).item() == self.forced_eos_token_id:
            main_model_only_flag = True
        if self._early_stop_triggered(input_ids):
            scores[:, self.forced_eos_token_id] = float('inf')
            main_model_only_flag = True

        if main_model_only_flag:
            self._append_record(ensemble_process_file_path, json_object)
            self.ensemble_model_output_ids_queue.put(torch.argmax(scores, dim=-1))
            return scores

        result_logits = self._ensemble_logits(scores, assist_model_generate_ids_logits_list, json_object)
        self._append_record(ensemble_process_file_path, json_object)
        self.ensemble_model_output_ids_queue.put(torch.argmax(result_logits, dim=-1))
        return result_logits.to(self.device).detach()

    def _collect_assist_logits(self, json_object):
        """Read one logits tensor per assist queue and record its top-10 into ``json_object``.

        Returns:
            ``(logits_list, missing)``: ``logits_list`` has ``None`` for every queue that stayed
            empty for ``self.queue_timeout`` seconds; ``missing`` is True if any did.
        """
        logits_list = []
        missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=self.queue_timeout)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                logits_list.append(None)
                missing = True
                continue
            logits_list.append(value)
            self._record_topk(json_object, f'aux_values_{index}', f'aux_indices_{index}', value)
        return logits_list, missing

    def _early_stop_triggered(self, input_ids):
        """Return True if row 0 of ``input_ids`` ends with the tokens of any early-stop string."""
        if self.early_stop_string_list is None:
            return False
        generated_ids = input_ids.tolist()[0]
        for early_stop_string in self.early_stop_string_list:
            # [1:] drops the first token of the encoding — presumably a prefix/space piece the
            # main tokenizer emits even without special tokens; TODO confirm per tokenizer.
            early_stop_token = self.main_model_tokenizer(early_stop_string, return_tensors="pt",
                                                         add_special_tokens=False).input_ids.tolist()[0][1:]
            # An empty token list would make generated_ids[-0:] the whole sequence; skip it.
            if early_stop_token and generated_ids[-len(early_stop_token):] == early_stop_token:
                return True
        return False

    def _resolve_model_weights(self, model_count):
        """Return one averaging weight per model (main first); uniform unless configured."""
        if self.model_weights is None:
            # 1/4 == 0.25 exactly, so four models reproduce the former hard-coded weights.
            return [1.0 / model_count] * model_count
        if len(self.model_weights) != model_count:
            raise ValueError(
                f"expected {model_count} model weights (main + assist models), "
                f"got {len(self.model_weights)}")
        return self.model_weights

    def _ensemble_logits(self, scores, assist_model_generate_ids_logits_list, json_object):
        """Run the projection / averaging / optimisation path and record its intermediates.

        Returns the float32 main-model logits on ``self.device_compute`` (a fresh copy).
        """
        main_matrix = self.main_model_probability_transfer_matrix_list[0]
        main_logits = scores.to(torch.float32)

        with torch.no_grad():
            main_probs = nn.functional.softmax(main_logits, dim=-1).float()
            main_relative_probs = torch.mm(main_probs, main_matrix).to(self.device_compute)
            self._record_topk(json_object, 'main_model_generate_ids_probs_values',
                              'main_model_generate_ids_probs_indices', main_probs)
            self._record_topk(json_object, 'main_rel_values', 'main_rel_indices', main_relative_probs)

            relative_probs_list = [main_relative_probs]
            for index, (assist_logits, assist_matrix) in enumerate(
                    zip(assist_model_generate_ids_logits_list,
                        self.assist_model_probability_transfer_matrix_list)):
                assist_probs = nn.functional.softmax(assist_logits, dim=-1).float()
                assist_relative_probs = torch.mm(assist_probs, assist_matrix).to(self.device_compute)
                self._record_topk(json_object, f'aux_rel_values_{index}', f'aux_rel_indices_{index}',
                                  assist_relative_probs)
                relative_probs_list.append(assist_relative_probs)

            # Weighted average of the projected distributions = optimisation target.
            weights = self._resolve_model_weights(len(relative_probs_list))
            target_probs = torch.zeros_like(main_relative_probs)
            for weight, probs in zip(weights, relative_probs_list):
                target_probs += weight * probs
            self._record_topk(json_object, "ensemble_rel_values", "ensemble_rel_indices", target_probs)

        main_matrix = main_matrix.to(self.device_compute)
        result_logits = main_logits.to(self.device_compute).detach().clone()

        # The optimised parameter is a copy of the main model's probabilities, not its logits.
        probs_param = nn.functional.softmax(result_logits, dim=-1).float().detach().clone()
        probs_param.requires_grad_(True)

        # enable_grad restores the caller's grad mode afterwards, even on exceptions
        # (generate() normally runs under no_grad).
        with torch.enable_grad():
            # Default reduction='mean' (PyTorch warns it also divides by the support size);
            # kept to preserve the original loss scaling.
            criterion = nn.KLDivLoss()
            optimizer = torch.optim.AdamW(params=[probs_param], lr=self.learning_rate, betas=(0.9, 0.999))
            # NOTE(review): T_max is fixed at 10 independent of learning_epochs_nums.
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)
            for _ in range(self.learning_epochs_nums):
                # Shift to non-negative and renormalise so the parameter stays a distribution.
                shifted = probs_param - probs_param.min()
                normalized = shifted / shifted.sum()
                projected = torch.mm(normalized, main_matrix)
                # NOTE(review): log(0) -> -inf if a projected entry is exactly zero.
                loss = criterion(torch.log(projected), target_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

        ensemble_probs = probs_param.detach()
        self._record_topk(json_object, "ensemble_result_values", "ensemble_result_indices", ensemble_probs)

        logger = logging.getLogger(__name__)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("optimised probability mass before renormalisation: %s", ensemble_probs.sum().item())
        ensemble_probs = ensemble_probs - ensemble_probs.min()
        ensemble_probs = ensemble_probs / ensemble_probs.sum()
        self._record_topk(json_object, "ensemble_result_prob_values", "ensemble_result_prob_indices",
                          ensemble_probs)

        # NOTE(review): the optimised distribution is only logged; the returned logits (and
        # hence the queued next token) are the main model's own. Confirm this record-only
        # behaviour is intended before relying on the ensemble to steer generation.
        return result_logits

    @staticmethod
    def _record_topk(json_object, values_key, indices_key, tensor, k=10):
        """Store the top-``k`` values and indices of row 0 of ``tensor`` under the given keys."""
        values, indices = torch.topk(tensor, k=k)
        json_object[values_key] = values.tolist()[0]
        json_object[indices_key] = indices.tolist()[0]

    @staticmethod
    def _append_record(file_path, json_object):
        """Append ``json_object`` as one UTF-8 JSON line to ``file_path``."""
        with open(file_path, "a+", encoding="utf-8") as process_file:
            process_file.write(json.dumps(json_object, ensure_ascii=False) + '\n')
