
import json
import os.path
import re
import sys
import os
from comet import download_model, load_from_checkpoint
from tqdm import tqdm

# Local checkpoint of the COMET model used for scoring. It can be overridden
# with the COMET_MODEL_PATH environment variable. Alternative that fetches a
# model from the hub instead:
#     model_path = download_model("Unbabel/XCOMET-XL")
model_path = os.environ.get(
    'COMET_MODEL_PATH',
    '/ssd3/home/username/ModelsHub/Unbabel/wmt22-comet-da/checkpoints/model.ckpt',
)

# Load the model checkpoint once at import time; get_comet_score() uses this global.
model = load_from_checkpoint(model_path)


def get_comet_score(src_file_content, sys_file_content, ref_file_content, batch_size=10, gpus=1):
    """Score system outputs against sources and references with the global COMET ``model``.

    Each segment is stripped of surrounding whitespace before scoring, and the
    number of triples is printed before prediction starts.

    Args:
        src_file_content: Iterable of source segments.
        sys_file_content: Iterable of system (MT) outputs, aligned with the sources.
        ref_file_content: Iterable of reference segments, aligned with the sources.
        batch_size: Batch size forwarded to ``model.predict`` (default 10).
        gpus: Number of GPUs forwarded to ``model.predict`` (default 1).

    Returns:
        The object returned by ``model.predict``; this script reads its
        ``scores`` and ``system_score`` attributes.

    Raises:
        ValueError: If the three inputs do not contain the same number of segments.
    """
    sources = list(src_file_content)
    hypotheses = list(sys_file_content)
    references = list(ref_file_content)
    # zip() would silently drop the tail of the longer inputs, scoring
    # misaligned or partial data; fail loudly instead.
    if not (len(sources) == len(hypotheses) == len(references)):
        raise ValueError(
            f"input length mismatch: {len(sources)} sources, "
            f"{len(hypotheses)} hypotheses, {len(references)} references"
        )
    datas = [
        {"src": src.strip(), "mt": hyp.strip(), "ref": ref.strip()}
        for src, hyp, ref in zip(sources, hypotheses, references)
    ]
    print(len(datas))
    return model.predict(datas, batch_size=batch_size, gpus=gpus)


def result_write(result_path, sys_file_name, comet_score):
    """Append one JSON line summarising *comet_score* to ``<result_path>/comet_score.jsonl``.

    Hyper-parameters are parsed from *sys_file_name*, which must contain
    ``lr<value>anchor_point_count<value>learning_epochs_nums<value>``.
    Underscores around each value are dropped, as is a trailing ``.jsonl``
    on the last value.

    Args:
        result_path: Directory holding the scored file and the summary file.
        sys_file_name: Base name of the scored system-output file.
        comet_score: Prediction object exposing ``scores`` (its length is
            recorded as ``count``) and ``system_score`` (recorded multiplied by 100).

    Raises:
        ValueError: If *sys_file_name* does not match the expected pattern.
    """
    match = re.search(r'lr(.*?)anchor_point_count(.*?)learning_epochs_nums(.*)', sys_file_name)
    if match is None:
        raise ValueError(f"cannot parse hyper-parameters from file name {sys_file_name!r}")
    lr, anchor_point_count, learning_epochs_nums = match.groups()

    # str.strip('.jsonl') removes any of the characters '.', 'j', 's', 'o', 'n',
    # 'l' from both ends rather than the suffix; remove the extension explicitly.
    suffix = '.jsonl'
    if learning_epochs_nums.endswith(suffix):
        learning_epochs_nums = learning_epochs_nums[:-len(suffix)]

    record = {
        'sys_file_path': os.path.join(result_path, sys_file_name),
        'learning_rate': lr.strip('_'),
        'anchor_point_count': anchor_point_count.strip('_'),
        'learning_epochs_nums': learning_epochs_nums.strip('_'),
        'count': len(comet_score.scores),
        'comet_system_score': comet_score.system_score * 100,
    }
    # Parse before opening so a bad file name does not create/touch the summary file.
    with open(os.path.join(result_path, 'comet_score.jsonl'), 'a', encoding='utf-8') as result_file:
        result_file.write(json.dumps(record, ensure_ascii=False) + '\n')


def find_files_with_suffix(folder_path, suffix):
    """Return the names of regular files directly inside *folder_path* ending with *suffix*.

    The scan is not recursive and names are returned without the directory
    prefix. Results are sorted because ``os.listdir`` order is arbitrary, so
    callers process files in a reproducible order. Sub-directories are skipped
    even if their names end with *suffix*.
    """
    return sorted(
        name
        for name in os.listdir(folder_path)
        if name.endswith(suffix) and os.path.isfile(os.path.join(folder_path, name))
    )


# Summary file appended to by result_write(). It lives in the scanned directory
# and also ends in ".jsonl", so it must be excluded from scoring: otherwise a
# rerun would try to read it as predictions and fail on the missing keys.
RESULT_FILE_NAME = 'comet_score.jsonl'

if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <directory containing prediction .jsonl files>")
file_home_dir = sys.argv[1]

jsonl_files_list = [
    name for name in find_files_with_suffix(file_home_dir, ".jsonl")
    if name != RESULT_FILE_NAME
]
for jsonl_file in tqdm(jsonl_files_list):
    print(jsonl_file)
    src_file_contents = []
    sys_file_contents = []
    ref_file_contents = []
    # Each line is a JSON object with 'question' (source), 'prediction'
    # (system output) and 'answer' (reference) fields.
    with open(os.path.join(file_home_dir, jsonl_file), 'r', encoding='utf-8') as f:
        for line in f:
            # Skip blank lines, which json.loads would reject.
            if not line.strip():
                continue
            json_obj = json.loads(line)
            src_file_contents.append(json_obj['question'].strip())
            sys_file_contents.append(json_obj['prediction'].strip())
            ref_file_contents.append(json_obj['answer'].strip())

    comet_score = get_comet_score(src_file_contents, sys_file_contents, ref_file_contents)
    print(comet_score.system_score)

    result_write(file_home_dir, jsonl_file, comet_score)
