import os
import re
import random
import argparse

import torch
import torch.distributed
import datasets

from transformers import LlamaTokenizer
from tqdm import tqdm
from fastchat.model import get_conversation_template

from my_configuration_llama import LlamaConfig

from train_pose import smart_tokenizer_and_embedding_resize, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_PAD_TOKEN, DEFAULT_UNK_TOKEN

# Device used for evaluation: CUDA when PyTorch reports it as available, otherwise the CPU.
gpu_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def truncate(tokenizer, prompt, max_length):
    """Randomly drop distractor lines from a line-retrieval prompt until it fits.

    The last line of ``prompt`` is expected to ask about a line named
    ``<word>-<word>`` (regex ``in line (\\w+-\\w+)\\?``). Interior lines (never the
    first or the last line) that do not contain that name are removed one at a
    time, uniformly at random, until the tokenized prompt is at most
    ``max_length`` tokens long.

    Args:
        tokenizer: Callable used as ``tokenizer(text, return_tensors="pt")``; its
            result must expose ``input_ids`` whose last dimension is the token
            count.
        prompt: Newline-separated prompt text.
        max_length: Maximum allowed number of tokens.

    Returns:
        The shortened prompt. If the last line does not match the expected
        question pattern, the prompt is printed and returned unchanged. If no
        removable line is left, the shortest achievable prompt is returned even
        if it still exceeds ``max_length``.
    """
    split_data = prompt.split('\n')

    match = re.search(r'in line (\w+-\w+)\?', split_data[-1])
    if not match:
        print(prompt)
        return prompt
    extracted_string = match.group(1)

    def _num_tokens(text):
        # Token count of ``text`` under the provided tokenizer.
        return tokenizer(text, return_tensors="pt").input_ids.shape[-1]

    prompt_length = _num_tokens(prompt)
    while prompt_length > max_length:
        # Removable lines: interior lines that do not mention the queried line.
        candidates = [
            i for i in range(1, len(split_data) - 1)
            if extracted_string not in split_data[i]
        ]
        if not candidates:
            # Nothing left to drop; stop instead of looping forever or raising.
            break
        del split_data[random.choice(candidates)]
        prompt_length = _num_tokens('\n'.join(split_data))

    return '\n'.join(split_data)


def test_lines_one_sample(model, tokenizer, test_case):
    """Run one LongChat line-retrieval example and score the model's answer.

    ``test_case["prompt"]`` is wrapped in the FastChat "vicuna" conversation
    template. The model then generates between 5 and 35 new tokens. The first
    run of digits in the decoded answer is compared with
    ``test_case["expected_number"]``.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used both to encode the prompt and decode the answer.
        test_case: Mapping with ``"prompt"`` and ``"expected_number"`` keys.

    Returns:
        ``(is_correct, prompt_length, summary)``. ``prompt_length`` is the token
        count of the templated prompt. ``summary`` is a single-line description of
        the label, raw prediction and parsed number. An answer with no digits is
        parsed as -1.
    """
    prompt = test_case["prompt"]
    expected_number = test_case["expected_number"]

    conv = get_conversation_template("vicuna")
    print(f"Using conversation template: {conv.name}")

    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)  # empty assistant turn to be completed
    prompt = conv.get_prompt()

    encoded = tokenizer(prompt, return_tensors="pt")
    prompt_length = encoded.input_ids.shape[-1]

    # Pass the attention mask explicitly rather than letting generate() infer it.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)

    output_ids = model.generate(
        input_ids=encoded.input_ids.to(model.device),
        attention_mask=attention_mask,
        min_new_tokens=5,
        max_new_tokens=35,
        use_cache=True,
    )[0]
    # Drop the echoed prompt tokens; keep only the newly generated continuation.
    output = tokenizer.decode(output_ids[prompt_length:], skip_special_tokens=True)

    # Parse the first run of digits (an integer, not a single digit) in the answer.
    digits = re.findall(r"\d+", output)
    if digits:
        response_number = int(digits[0])
    else:
        print("Got unparsable result")
        response_number = -1

    summary = f"Label: {expected_number}, Predict: {output}, Parsed: {response_number}, prompt length: {prompt_length}".replace('\n', ' ')
    print(summary)

    return expected_number == response_number, prompt_length, summary


def _parse_args():
    """Parse the command-line options of the LongChat-Lines evaluation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_shortest_only", action="store_true", default=False)
    parser.add_argument("--model_max_position_embeddings", type=int, default=2048)
    parser.add_argument("--rope_scaling_factor", type=float, default=1.0)
    parser.add_argument("--rope_scaling_type", type=str, default=None)
    parser.add_argument("--model_name", type=str, default="llama-7b")
    # Both paths are needed for the run to succeed, so argparse enforces them.
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--path_to_ckp", type=str, required=True)
    parser.add_argument("--use_flash_attn", type=int, default=1)
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="/scratch/nlp/wutong/dataset/PoSE-Datasets/LongChat-Lines",
        help="Directory readable by datasets.load_from_disk.",
    )
    return parser.parse_args()


def _load_model_and_tokenizer(args):
    """Build the LLaMA config, model (on ``gpu_device``, eval mode) and tokenizer.

    Returns:
        ``(config, model, tokenizer)``.
    """
    if args.use_flash_attn:
        from my_flash_modeling_llama import LlamaForCausalLM
    else:
        from my_modeling_llama import LlamaForCausalLM

    model_name_or_path = args.path_to_ckp
    config = LlamaConfig.from_pretrained(model_name_or_path)

    # Apply command-line RoPE scaling only if the checkpoint config has none.
    if config.rope_scaling is None and args.rope_scaling_type is not None:
        config.rope_scaling = {"type": args.rope_scaling_type, "factor": args.rope_scaling_factor}
        config.max_position_embeddings = int(args.model_max_position_embeddings * args.rope_scaling_factor)
        if args.rope_scaling_type == "yarn":
            config.rope_scaling["original_max_position_embeddings"] = args.model_max_position_embeddings
    config.use_cache = True

    print(f'load model from {model_name_or_path}')
    model = LlamaForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_name_or_path,
        config=config,
        torch_dtype=torch.float16,
    )
    model.to(gpu_device)
    model.eval()

    print('load tokenizer')
    is_baichuan = "baichuan" in args.model_name
    is_llama = "llama" in args.model_name
    tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=not is_baichuan)

    # Add the pad token exactly once: for every non-Baichuan model and for any llama model.
    # (smart_tokenizer_and_embedding_resize comes from train_pose; presumably it also
    # resizes the model embeddings — see that module.)
    if is_llama or not is_baichuan:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if is_llama:
        # NOTE(review): these are added after the embedding resize; if any of them is a
        # genuinely new token, the embedding matrix will not cover it — confirm they
        # already exist in the vocabulary.
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )
    return config, model, tokenizer


def _num_lines_key(split_name):
    """Sort key: purely numeric split names first, in numeric (not lexicographic) order."""
    name = str(split_name)
    return (0, int(name), name) if name.isdigit() else (1, 0, name)


def _evaluate_split(model, tokenizer, test_cases):
    """Return ``(average prompt length, accuracy)`` over a non-empty split."""
    num_cases = len(test_cases)
    num_correct = 0
    total_length = 0
    for test_case in tqdm(test_cases):
        correct, prompt_length, _ = test_lines_one_sample(model=model, tokenizer=tokenizer, test_case=test_case)
        total_length += prompt_length
        num_correct += correct
    return total_length / num_cases, num_correct / num_cases


def main():
    """Evaluate a LLaMA checkpoint on the LongChat-Lines retrieval dataset.

    Each dataset split is named by the number of lines per prompt. For each
    evaluated split, the average prompt length and accuracy are printed and
    appended to ``<output_dir>/response_<model_name>.txt``.
    """
    args = _parse_args()
    config, model, tokenizer = _load_model_and_tokenizer(args)

    lines_dataset = datasets.load_from_disk(args.dataset_path)
    # Split names are strings; order numerically so "200" comes before "1000".
    lines = sorted(lines_dataset.keys(), key=_num_lines_key)

    if args.eval_shortest_only:
        lines = lines[:1]

    # Without RoPE scaling, only the first (shortest) five splits are evaluated.
    if config.rope_scaling is None:
        lines = lines[:5]

    os.makedirs(args.output_dir, exist_ok=True)
    output_file = os.path.join(args.output_dir, f"response_{args.model_name}.txt")

    for num_lines in lines:
        print(f"************ Start testing {num_lines} lines per LRT prompt ************")

        test_cases = lines_dataset[num_lines]
        if len(test_cases) == 0:
            print(f"Skipping {num_lines}: no test cases")
            continue

        avg_length, accuracy = _evaluate_split(model, tokenizer, test_cases)
        summary = f"************ Finish testing {num_lines} lines per prompt with average prompt length {avg_length}, accuracy: {accuracy} ************"

        with open(output_file, "a+", encoding="utf-8") as f:
            f.write(summary + "\n\n")

        print(summary)


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
