import os, sys
# Resolve the parent of the directory holding this script and append it to the
# import search path (presumably so the project packages imported further down
# resolve when the script is launched directly -- confirm against repo layout).
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

import argparse
import pickle
from collections import defaultdict

import torch.types
from tqdm import tqdm

from utils.utils import set_seed, get_prompt_weave_v20_llama3b, get_prompt_2_weave_v20_llama3b, \
    get_prompt_3_weave_v20_llama3b, get_prompt_4_weave_v20_llama3b, get_prompt_5_weave_v20_llama3b, \
    get_prompt_6_weave_v20_llama3b
from datas.get_data import get_data
from torch.utils.data import DataLoader
from utils.utils import get_model
from utils.utils import get_promt
import torch.nn.functional as F
from utils.utils import compare_retrieval_acc

import numpy as np
import os

# Keyword arguments unpacked into the ``model.generate`` call made by ``main``.
# NOTE(review): no ``do_sample`` entry is set here -- confirm that
# ``temperature`` / ``top_p`` actually influence decoding for the loaded model.
model_custom_config = dict(
    max_new_tokens=50,
    temperature=0.1,
    top_p=0.9,
)

def _configure_weave_method(args):
    """Push CLI tuning values onto the weave-mpt module named by ``args.method``.

    For ``weave-mpt1/2/3/6/7`` the matching ``methods.weave_mptN`` module is
    imported and a method-specific subset of its module attributes
    (``push_pos``, ``push_width``, ``chunk_width``) is overwritten from
    ``args``. Any other method name is left untouched.

    Args:
        args: Parsed CLI namespace. ``method`` is always read; ``push_mpt``,
            ``push_width`` and ``chunk_width`` are read only when the selected
            method uses them.
    """
    import importlib

    # method name -> module attributes that method expects to be overridden.
    overrides_by_method = {
        "weave-mpt1": ("push_pos",),
        "weave-mpt2": ("push_pos",),
        "weave-mpt3": ("push_pos", "push_width"),
        "weave-mpt6": ("push_pos", "push_width", "chunk_width"),
        "weave-mpt7": ("push_pos", "push_width", "chunk_width"),
    }
    attr_names = overrides_by_method.get(args.method)
    if attr_names is None:
        return

    # module attribute -> name of the CLI argument supplying its value.
    cli_source = {
        "push_pos": "push_mpt",
        "push_width": "push_width",
        "chunk_width": "chunk_width",
    }
    weave_mpt = importlib.import_module("methods." + args.method.replace("-", "_"))
    for attr in attr_names:
        setattr(weave_mpt, attr, getattr(args, cli_source[attr]))

    if args.method == "weave-mpt1":
        # weave-mpt1 additionally pins the attention chunk width to a fixed
        # value; this one is deliberately not taken from the command line.
        weave_attention = importlib.import_module("models.mpt_7b.weave_attention")
        weave_attention.chunk_width = 2047


def main(args):
    """Run the retrieval evaluation and checkpoint per-length accuracy to disk.

    For every batch from the dataset, the first sample's text is inserted into
    the prompt template, the model generates a continuation, and the
    continuation is scored against the sample's target with
    ``compare_retrieval_acc``. Scores are grouped by the sample's
    ``token_length``. After every batch two pickle files are (re)written under
    ``<cwd>/<args.log_dir>``:

    * ``args.save_file``: ``{"length_mean_var": {length: {"mean", "var"}}}``
    * ``"record_" + args.save_file``: ``{"all_length_acc": {length: [scores]}}``

    Only index 0 of each batch is evaluated, so a ``batch_size`` above 1 skips
    the remaining samples of every batch.

    Args:
        args: Parsed CLI namespace providing ``method``, ``dataset``,
            ``batch_size``, ``cuda``, ``model_path``, ``log_dir`` and
            ``save_file`` (plus the weave-mpt tuning values when applicable).
    """
    _configure_weave_method(args)

    # Create the output directory up front so a missing folder fails fast
    # instead of after the model has been loaded and run on the first sample.
    log_dir = os.path.join(os.getcwd(), args.log_dir)
    os.makedirs(log_dir, exist_ok=True)
    summary_path = os.path.join(log_dir, args.save_file)
    record_path = os.path.join(log_dir, "record_" + args.save_file)

    dataset = get_data(args.dataset)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # "auto" is forwarded verbatim; anything else is treated as a device index.
    if args.cuda == "auto":
        device = "auto"
    else:
        device = torch.device(int(args.cuda))
    tokenizer, model = get_model(args.model_path, device, method=args.method, args=args)

    prefix_prompt = get_promt(args.model_path)
    if "weave-v20-llama3b" in args.method or "vicuna" in args.model_path or "llama" in args.model_path:
        # These models use the dedicated "prompt 6" template instead.
        prefix_prompt = get_prompt_6_weave_v20_llama3b()

    all_length_acc = defaultdict(list)

    # Inference only: switch to eval mode once rather than on every batch.
    model.eval()

    for data in tqdm(dataloader):
        with torch.no_grad():
            query = prefix_prompt.format(data["text"][0])
            inputs_token = tokenizer(query, return_tensors="pt").to(model.device)
            input_ids = inputs_token.input_ids
            prompt_len = input_ids.shape[-1]
            print(f"input token length: {prompt_len}")

            outputs = model.generate(input_ids, **model_custom_config)

            # Drop the prompt by token count. Slicing the decoded string by
            # len(query) is off whenever decoding does not reproduce the prompt
            # text exactly (e.g. a prepended BOS token or whitespace cleanup).
            response = tokenizer.decode(outputs[0][prompt_len:])
            print(f"response: {response}")
            print("target: {}".format(data["target"][0]))

            acc = compare_retrieval_acc(response, data["target"][0])
            print("success" if acc == 1 else "failed")

            token_length = int(data["token_length"][0])
            all_length_acc[token_length].append(acc)

        # Checkpoint after every batch so partial results survive an
        # interrupted run.
        all_mean_var_res = {
            length: {
                "mean": np.nanmean(np.array(record)),
                "var": np.nanvar(np.array(record)),
            }
            for length, record in all_length_acc.items()
        }

        with open(summary_path, "wb") as f:
            pickle.dump({"length_mean_var": all_mean_var_res}, f)

        with open(record_path, "wb") as f:
            pickle.dump({"all_length_acc": all_length_acc}, f)


if __name__ == "__main__":
    # Script entry point: declare the command-line options, seed the RNGs via
    # set_seed, then hand the parsed namespace to main().
    parser = argparse.ArgumentParser()

    # (flag, type, default) for every option, in help-output order.
    option_specs = (
        ("--model_path", str, "/data/persist/models/mosaicml-mpt-7b"),  # mosaicml-mpt-7b
        ("--method", str, "weave-mpt7"),
        ("--dataset", str, "../datas/passkey-data_dup-10_answer-6bit.json"),
        ("--save_file", str, "retrieval_streaming-llm-llama3b-newprompt6_test.pkl"),
        ("--batch_size", int, 1),
        ("--log_dir", str, "../logs"),
        ("--cuda", str, "1"),
        ("--hard_cuda", int, 0),
        ("--seed", int, 0),
        # for mpt-alibi
        ("--push_mpt", int, 512),
        ("--push_width", int, 50),
        ("--chunk_width", int, 512),
    )
    for flag, value_type, default_value in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value)

    args = parser.parse_args()
    set_seed(args.seed)
    main(args)

















