import os
import pdb
import sys
import time
import json
import torch
import queue
import logging
from tqdm import tqdm

# Directory prepended to sys.path so the ``src.*`` imports that follow resolve when this
# script is run directly (presumably the project root containing the ``src`` package).
# The original hard-coded path stays the default; set LLM_ENSEMBLE_ROOT to use another checkout.
root_dir = os.environ.get('LLM_ENSEMBLE_ROOT', '/data/home/username/Experiments/LLM_ensemble')
sys.path.insert(0, root_dir)
from src.instruction_generate import demon_prompt_generate, task_instruction_generate



import argparse
from src.main_model_thread import MainModelThread
from src.model_load import load_model
from src.assist_model_thread import AssistModelThread
from src.common_vocabulary import CommonVocabulary
from src.transfer_matrix import ProbabilityTransferMatrix


# NOTE(review): main-model token ids in this closed range are treated as fragments of a
# multi-byte UTF-8 character that may have to be buffered until the buffered ids decode
# without a U+FFFD replacement character. This presumably corresponds to the high
# byte-fallback tokens of a Llama-style SentencePiece vocabulary — TODO confirm for the
# configured main model.
_BYTE_TOKEN_ID_MIN = 130
_BYTE_TOKEN_ID_MAX = 258


def _build_arg_parser():
    """Return the command-line parser for the ensemble generation script."""
    parser = argparse.ArgumentParser(description='Process some files.')

    parser.add_argument('--config', required=True,
                        help='path to the JSON config file (model paths, prompt templates, run parameters)')
    parser.add_argument('--learning_rate', '-lr', default=0.0, type=float, required=False, help="learning_rate")
    parser.add_argument('--anchor_point_count', '-apc', default=32000, type=int, required=False,
                        help='anchor_point_count')
    parser.add_argument('--learning_epochs_nums', '-len', default=5, type=int, required=False,
                        help='learning_epochs_nums')
    parser.add_argument('--result_save_dir', '-rsd', default="./", type=str, required=False, help='result_save_dir')
    parser.add_argument('--run_mode', '-rm', default="dev", type=str, required=False,
                        help='"dev" reads file_path.dev_file_path; any other value reads file_path.test_file_path')
    parser.add_argument('--logits_processor_mode', '-lpm', default="based_on_probility_transfer_logits_local_processor",
                        type=str,
                        required=False,
                        help='logits_processor_mode')
    parser.add_argument('--device_compute', '-dp', default="cuda:1", type=str, required=False,
                        help='device_compute')
    parser.add_argument('--device0', '-d0', default="auto", type=str, required=False,
                        help='device0')
    parser.add_argument('--device1', '-d1', default="auto", type=str, required=False,
                        help='device1')
    parser.add_argument('--device2', '-d2', default="auto", type=str, required=False,
                        help='device2')
    parser.add_argument('--device3', '-d3', default="auto", type=str, required=False,
                        help='device3')

    # NOTE(review): the four options below are parsed but not read anywhere in this script.
    parser.add_argument('--main_temperature', '-mt', default=100, type=float, required=False,
                        help='main_temperature')
    parser.add_argument('--assist_temperature', '-at', default=100, type=float, required=False,
                        help='assist_temperature')
    parser.add_argument('--min_prob', default=0.8, type=float, required=False,
                        help='min_prob')
    parser.add_argument('--max_prob', default=0.9, type=float, required=False,
                        help='max_prob')
    return parser


def _count_completed_samples(result_file_path):
    """Return how many lines *result_file_path* already holds, or 0 if it does not exist.

    ``main`` skips that many input samples so an interrupted run can resume. The file is
    read in binary mode, so a decoding problem cannot silently reset the count to 0.
    """
    try:
        with open(result_file_path, 'rb') as result_file:
            return sum(1 for _ in result_file)
    except FileNotFoundError:
        return 0


def main():
    """Run token-level ensemble generation with one main model and three assist models.

    Reads the JSON config given by ``--config`` (model paths, probability transfer matrix
    paths, input/demonstration file paths, prompt templates, run parameters). It then
    processes every sample of the dev file (``--run_mode dev``) or the test file (any other
    mode) that is not yet a line of the run's ``.jsonl`` result file. For each sample it
    starts a ``MainModelThread`` and, step by step:

    1. starts one ``AssistModelThread`` per assist model on that model's current prompt,
    2. waits for the next ensembled token id on ``ensemble_model_output_ids_queue``,
    3. decodes the id with the main tokenizer and appends the text to every assist prompt.

    A sample ends after ``max_new_tokens`` steps, when the decoded text is ``</s>`` or
    ``<unk>``, or when no id arrives within the queue timeout. Who writes the result file
    is not visible here (presumably ``MainModelThread`` — confirm).
    """
    start_time = time.time()

    args = _build_arg_parser().parse_args()

    # Load the run configuration.
    with open(args.config, 'r', encoding='utf-8') as f:
        config_json = json.load(f)

    model_path_config = config_json["model_path"]
    main_model_path = model_path_config["main_model_path"]
    assist_model_paths = [model_path_config[f"assist_model{k}_path"] for k in (1, 2, 3)]

    matrix_path_config = config_json["probability_transfer_matrix_path"]
    main_model_probability_transfer_matrix_path = matrix_path_config["main_model_path"]
    assist_model_probability_transfer_matrix_paths = [
        matrix_path_config[f"assist_model{k}_path"] for k in (1, 2, 3)
    ]

    file_path_config = config_json["file_path"]
    dev_file_path = file_path_config["dev_file_path"]
    test_file_path = file_path_config["test_file_path"]
    demon_file_path = file_path_config["demon_file_path"]

    prompt_template = config_json["prompt_template"]
    instruction = prompt_template["instruction"]
    instruction_parameter = prompt_template["instruction_parameter"]
    main_model_system_template = prompt_template["main_model_system_template"]
    assist_model_system_templates = [
        prompt_template[f"assist_model{k}_system_template"] for k in (1, 2, 3)
    ]

    run_parameter = config_json["run_parameter"]
    max_new_tokens = run_parameter["max_new_tokens"]
    # Optional; falls back to id 2 and is forwarded to MainModelThread as 'forced_eos_token_id'.
    end_token_id = run_parameter.get("end_token_id", 2)

    demon_parameter = prompt_template["demon_parameter"]

    result_process_parameter = config_json["result_process_parameter"]
    early_stop_string_list = result_process_parameter.get("early_stop_string_list")

    result_save_dir = args.result_save_dir
    logits_processor_mode = args.logits_processor_mode
    os.makedirs(result_save_dir, exist_ok=True)

    anchor_point_count = args.anchor_point_count
    learning_rate = args.learning_rate
    learning_epochs_nums = args.learning_epochs_nums
    run_mode = args.run_mode

    device_compute = args.device_compute
    device0 = args.device0
    # One device per assist model, in the same order as the assist model lists above.
    assist_devices = [args.device1, args.device2, args.device3]

    input_file_path = dev_file_path if run_mode == "dev" else test_file_path

    # The process log and the result file share this run-specific file-name stem.
    run_stem = (f'ensemble_lr{learning_rate}_anchor_point_count{anchor_point_count}'
                f'_learning_epochs_nums{learning_epochs_nums}')

    logging.basicConfig(filename=os.path.join(result_save_dir, f'{run_stem}.process.log'),
                        level=logging.DEBUG)
    logging.info(f'\n【config_json:】{config_json}')
    logging.info(f'\n【result_save_dir:】{result_save_dir}')
    logging.info(f'\n【anchor_point_count:】{anchor_point_count}')
    logging.info(f'\n【learning_rate:】{learning_rate}')
    logging.info(f'\n【learning_epochs_nums:】{learning_epochs_nums}')

    # NOTE(review): torch.load unpickles; the matrix paths come from the config file and
    # must therefore be trusted.
    print("loading probability transfer matrix")
    main_model_probability_transfer_matrix = torch.load(main_model_probability_transfer_matrix_path,
                                                        map_location=device0)
    assist_model_probability_transfer_matrix_list = [
        torch.load(matrix_path, map_location=device)
        for matrix_path, device in zip(assist_model_probability_transfer_matrix_paths, assist_devices)
    ]
    main_model_probability_transfer_matrix_list = [main_model_probability_transfer_matrix]

    main_model, main_model_tokenizer, _ = load_model(main_model_path, "auto")
    assist_models = []
    assist_model_tokenizer_list = []
    for assist_model_path in assist_model_paths:
        assist_model, assist_model_tokenizer, _ = load_model(assist_model_path, "auto")
        assist_models.append(assist_model)
        assist_model_tokenizer_list.append(assist_model_tokenizer)

    common_vocabulary = CommonVocabulary(main_model_tokenizer, *assist_model_tokenizer_list)
    common_vocab_list = common_vocabulary.get_common_vocab_list(*common_vocabulary.vocabs)
    probability_transfer_matrix = ProbabilityTransferMatrix()
    anchor_point_list = probability_transfer_matrix.get_anchor_point_list(common_vocab_list=common_vocab_list)

    # Resume support: skip as many input samples as the result file already has lines.
    result_file_path = os.path.join(result_save_dir, f'{run_stem}.jsonl')
    start_index = _count_completed_samples(result_file_path)

    with open(input_file_path, 'r', encoding='utf-8') as input_file:
        contents = input_file.readlines()

    try:
        demon_instruction, demon_count = demon_prompt_generate(demon_file_path, demon_parameter)
    except Exception:
        # Best effort: run with an empty demonstration prefix, but record why.
        logging.exception('\ndemon_prompt_generate failed; continuing without demonstrations')
        demon_instruction, demon_count = "", 0

    # Seconds to wait for each ensembled token id; loop-invariant, so computed once.
    queue_timeout = 4 + 0.0167 * max_new_tokens

    for raw_line in tqdm(contents[start_index:]):
        sample = json.loads(raw_line)

        task_instruction = task_instruction_generate(sample, instruction_parameter)
        final_input_prompt = instruction + demon_instruction + task_instruction
        main_model_input = main_model_system_template.format(final_input_prompt)

        information_dict = {key: sample[key] for key in demon_parameter['key']}
        information_dict.update(
            main_model_input=main_model_input,
            demon_count=demon_count,
            task_instruction=task_instruction,
            max_new_tokens=max_new_tokens,
            result_process_parameter=result_process_parameter,
            logits_processor_mode=logits_processor_mode,
            anchor_point_list=anchor_point_list,
            forced_eos_token_id=end_token_id,
        )

        # Fresh queues per sample: ensembled ids come back on the first queue, and each
        # assist thread is handed the score queue at the same position in the list.
        ensemble_model_output_ids_queue = queue.Queue()
        assist_model_score_queue_list = [queue.Queue() for _ in assist_models]

        main_model_thread = MainModelThread(main_model=main_model,
                                            main_model_tokenizer=main_model_tokenizer,
                                            assist_model_tokenizer=assist_model_tokenizer_list,
                                            information_dict=information_dict,
                                            learning_rate=learning_rate,
                                            anchor_point_count=anchor_point_count,
                                            learning_epochs_nums=learning_epochs_nums,
                                            result_save_dir=result_save_dir,
                                            ensemble_model_output_ids_queue=ensemble_model_output_ids_queue,
                                            assist_model_score_queue_list=assist_model_score_queue_list,
                                            main_model_probability_transfer_matrix_list=main_model_probability_transfer_matrix_list,
                                            assist_model_probability_transfer_matrix_list=assist_model_probability_transfer_matrix_list,
                                            device_compute=device_compute,
                                            device=device0,
                                            early_stop_string_list=early_stop_string_list
                                            )
        main_model_thread.start()
        # NOTE(review): main_model_thread is never joined, so the next sample's thread may
        # start while this one is still running — confirm MainModelThread finishes on its own.

        # block_flag is True while byte-token ids are being buffered in temp_tensor
        # because they do not yet decode to a complete character.
        block_flag = False
        temp_tensor = torch.tensor([], dtype=torch.int64, device=device_compute)

        assist_model_inputs = [template.format(final_input_prompt)
                               for template in assist_model_system_templates]

        for i in range(max_new_tokens):
            # While bytes are buffered the assist prompts are unchanged, so no new assist
            # threads are started for this step.
            if not block_flag:
                for assist_model, assist_model_tokenizer, assist_model_input, score_queue, device in zip(
                        assist_models, assist_model_tokenizer_list, assist_model_inputs,
                        assist_model_score_queue_list, assist_devices):
                    AssistModelThread(model=assist_model,
                                      model_tokenizer=assist_model_tokenizer,
                                      assist_model_input=assist_model_input,
                                      assist_model_score_queue=score_queue,
                                      device=device,
                                      result_save_dir=result_save_dir
                                      ).start()

            try:
                ensemble_model_generate_next_id = ensemble_model_output_ids_queue.get(
                    block=True, timeout=queue_timeout).to(device_compute)
                logging.info(f'{i}, {main_model_tokenizer.convert_ids_to_tokens(ensemble_model_generate_next_id)}')
                print(i, main_model_tokenizer.convert_ids_to_tokens(ensemble_model_generate_next_id))
            except queue.Empty:
                # No id arrived in time: treat this sample as finished.
                print("ending")
                logging.info('\nending')
                break
            except Exception:
                # Best effort: a failure while handling the received item ends this sample
                # rather than the whole run, but the traceback is logged.
                print("ending")
                logging.exception('\nending: unexpected error while handling ensemble output')
                break

            next_id = ensemble_model_generate_next_id.tolist()[0]
            is_byte_token = _BYTE_TOKEN_ID_MIN <= next_id <= _BYTE_TOKEN_ID_MAX

            if block_flag or is_byte_token:
                temp_tensor = torch.cat([temp_tensor, ensemble_model_generate_next_id], dim=0)
            else:
                temp_tensor = ensemble_model_generate_next_id

            tokens = main_model_tokenizer.convert_ids_to_tokens(temp_tensor)
            first_token = tokens[0]
            string = main_model_tokenizer.convert_tokens_to_string(tokens)

            if is_byte_token:
                if "�" in string:
                    # Incomplete multi-byte character: keep buffering byte tokens.
                    block_flag = True
                    continue
                new_token = string
            elif next_id < _BYTE_TOKEN_ID_MIN:
                new_token = string
            else:
                # NOTE(review): if bytes are still buffered (block_flag) when a token with
                # id > _BYTE_TOKEN_ID_MAX arrives, first_token is the first *buffered byte*
                # token, so the newly generated token is dropped — confirm this is intended.
                if isinstance(first_token, bytes):
                    first_token = first_token.decode("utf-8")
                # A leading '▁' (word-boundary marker) is turned back into a real space.
                if first_token.startswith('▁'):
                    new_token = " " + first_token[1:]
                else:
                    new_token = first_token

            if new_token == "</s>" or new_token == "<unk>":
                break

            assist_model_inputs = [text + new_token for text in assist_model_inputs]

            temp_tensor = torch.tensor([], dtype=torch.int64, device=device_compute)
            block_flag = False

    minutes, seconds = divmod(int(time.time() - start_time), 60)
    logging.info(f"\nTime taken: {minutes} min {seconds} sec")
    print('Time taken: {} min {} sec'.format(minutes, seconds))


if __name__ == "__main__":
    # Entry point: run the ensemble only when executed as a script, never on import.
    main()
