import argparse
import os
import pdb
import sys
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

# Root of the LLM_ensemble project checkout. It is placed at the front of
# sys.path so the project-local ``src`` package (imported right after this)
# resolves when the file is executed directly as a script.
root_dir = '/data/home/username/Experiments/LLM_ensemble'
sys.path[:0] = [root_dir]

from src.nllb.Model_generator import nllb_translate

def parse_args(argv=None):
    """Parse the command-line options for a FLORES translation run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``src_lang``, ``tgt_lang``, ``mode``,
        ``device`` and ``max_length``.
    """
    parser = argparse.ArgumentParser(description='Process some files.')

    parser.add_argument('--src_lang', default="eng_Latn", type=str,
                        help='source language code used for the tokenizer and input file name (e.g. eng_Latn)')
    parser.add_argument('--tgt_lang', default="ron_Latn", type=str,
                        help='target language code passed to nllb_translate (e.g. ron_Latn)')
    parser.add_argument('--mode', default="dev", type=str,
                        help='FLORES split: "dev" reads the dev split; any other value M reads dev<M>/*.devtest (use "test")')
    parser.add_argument('--device', default="cuda:0", type=str,
                        help='torch device the model is moved to (default: cuda:0)')
    parser.add_argument('--max_length', default=200, type=int,
                        help='max_length forwarded to nllb_translate (default: 200)')
    return parser.parse_args(argv)


def build_io_paths(root, src_lang, tgt_lang, mode):
    """Return the ``(input_path, output_path)`` pair for one FLORES run.

    Args:
        root: Project root that contains the ``Datasets`` and ``Eval`` folders.
        src_lang: Source language code; selects the input file name.
        tgt_lang: Target language code; used in the output directory/file name.
        mode: ``"dev"`` reads ``Datasets/Flores/dev/<src_lang>.dev``; any other
            value ``m`` reads ``Datasets/Flores/dev<m>/<src_lang>.devtest``, so
            ``"test"`` selects the ``devtest`` split.

    Returns:
        Tuple of the source-sentence file path and the translation output path.
    """
    if mode == "dev":
        input_path = f"{root}/Datasets/Flores/{mode}/{src_lang}.dev"
    else:
        input_path = f"{root}/Datasets/Flores/dev{mode}/{src_lang}.devtest"

    run_name = f"v4-NLLB-200-distilled-600M-{src_lang}-{tgt_lang}-{mode}"
    output_path = f"{root}/Eval/Flores-{src_lang}-{tgt_lang}/{run_name}/{tgt_lang}_0.0.txt"
    return input_path, output_path


def main(argv=None):
    """Translate every line of a FLORES split with NLLB and write one line per result.

    The output file is overwritten on each run so that line ``i`` of the output
    always corresponds to line ``i`` of the input.
    """
    args = parse_args(argv)
    print(args)
    src_lang = args.src_lang
    tgt_lang = args.tgt_lang
    mode = args.mode
    device = args.device

    NLLB_model_path = "/data/home/username/ModelsHub/facebook/nllb-200-distilled-600M"
    NLLB_tokenizer = AutoTokenizer.from_pretrained(NLLB_model_path, src_lang=src_lang)
    NLLB_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_model_path).to(device)
    NLLB_model.eval()

    # root_dir is the module-level project root that is also put on sys.path.
    input_file_path, output_file_path = build_io_paths(root_dir, src_lang, tgt_lang, mode)
    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)

    # Read the split up front so the input handle is closed before the long loop.
    with open(input_file_path, 'r', encoding="utf-8") as src_file:
        src_contents = src_file.readlines()

    # Open the output once in "w" mode: reopening it per line in "a+" mode
    # appended onto any previous run's file, duplicating lines and breaking
    # alignment with the source. buffering=1 (line-buffered text mode) still
    # pushes each translation to the OS as soon as its newline is written.
    # no_grad avoids building autograd state during pure inference.
    with open(output_file_path, "w", encoding="utf-8", buffering=1) as f_result, torch.no_grad():
        for line in tqdm(src_contents):
            nllb_input_text = line.strip()

            # nllb_translate is project-local; its result is concatenated with
            # "\n" below, so it is expected to be a str.
            result = nllb_translate(NLLB_model, NLLB_tokenizer, nllb_input_text,
                                    max_length=args.max_length, tgt_lang=tgt_lang,
                                    device=device)
            f_result.write(result + "\n")


if __name__ == '__main__':
    main()
