import json
import os
import random
import re

import nltk
from tqdm import tqdm


# Number of "line <key>: REGISTER_CONTENT is <n>" records written into each
# generated prompt; also used as the stem of the output file name.
line_num = 1330


# One-time setup: the lookup below needs the NLTK 'words' corpus to be
# installed locally; uncomment to fetch it.
# nltk.download('words')
# NOTE(review): no tagger usage is visible in this file -- this second
# download looks unnecessary; confirm before relying on it.
# nltk.download('averaged_perceptron_tagger')

# Vocabulary from which the random word-pair line keys are drawn.
word_list = nltk.corpus.words.words()


from transformers import LlamaTokenizer

# Special-token strings registered on the tokenizer when it ships without a
# padding token.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
# NOTE(review): this is the same string as the EOS token; confirm that is
# intended rather than a dedicated unknown-token string.
DEFAULT_UNK_TOKEN = "</s>"

# NOTE(review): `use_fast` is normally an AutoTokenizer switch; confirm it has
# any effect when passed to LlamaTokenizer directly.
tokenizer = LlamaTokenizer.from_pretrained('/scratch2/nlp/plm/Llama-2-7b-hf', use_fast=True)

if tokenizer.pad_token is None:
    # The branch is entered precisely because no pad token exists, so register
    # one along with the other special tokens; previously DEFAULT_PAD_TOKEN was
    # defined but never applied and pad_token stayed None.
    tokenizer.add_special_tokens(
        {
            "pad_token": DEFAULT_PAD_TOKEN,
            "eos_token": DEFAULT_EOS_TOKEN,
            "bos_token": DEFAULT_BOS_TOKEN,
            "unk_token": DEFAULT_UNK_TOKEN,
        }
    )
tokenizer.padding_side = "left"

# Number of prompts to generate; also the divisor for the reported average.
NUM_SAMPLES = 50
# Inclusive range the REGISTER_CONTENT values are drawn from.
MIN_REGISTER_VALUE = 10000
MAX_REGISTER_VALUE = 100000
# Candidate words longer than this many characters are rejected.
MAX_WORD_LEN = 7

PROMPT_HEADER = "Below is a record of lines I want you to remember. Each line begins with 'line <line index>' and contains a '<REGISTER_CONTENT>' at the end of the line as a numerical value. For each line index, memorize its corresponding <REGISTER_CONTENT>. At the end of the record, I will ask you to retrieve the corresponding <REGISTER_CONTENT> of a certain line index. Now the record start:\n\n"


def get_pair(ground_num, prompt):
    """Return the line key whose REGISTER_CONTENT in *prompt* is *ground_num*.

    Args:
        ground_num: The register value to look up.
        prompt: Text containing "line <key>: REGISTER_CONTENT is <n>" lines.

    Returns:
        The text between "line " and ": REGISTER_CONTENT" on the first
        matching line.

    Raises:
        ValueError: If no line in *prompt* carries *ground_num*.
    """
    pattern = fr"line (.+): REGISTER_CONTENT is <{ground_num}>"

    match = re.search(pattern, prompt)
    if match is None:
        raise ValueError(f"no line with REGISTER_CONTENT <{ground_num}> in prompt")
    return match.group(1)


def _draw_word(vocabulary, taken):
    """Pick a random lower-cased word of at most MAX_WORD_LEN characters that
    is not yet in *taken*, add it to *taken*, and return it."""
    while True:
        word = random.choice(vocabulary).lower()
        if len(word) <= MAX_WORD_LEN and word not in taken:
            taken.add(word)
            return word


def _build_sample(num_lines, vocabulary):
    """Build one retrieval prompt and return ``(prompt, expected_number)``.

    The prompt holds *num_lines* records, each keyed by a pair of words that
    appear nowhere else in the prompt and carrying a number that is unique
    within the prompt, followed by a question about one randomly chosen record.

    Raises:
        ValueError: If *num_lines* is smaller than 1, or larger than the
            number of distinct register values available.
    """
    if num_lines < 1:
        raise ValueError("num_lines must be at least 1")

    # Sampling without replacement yields unique numbers directly, replacing
    # the draw-and-retry loop that scanned a growing list.
    out_num_list = random.sample(
        range(MIN_REGISTER_VALUE, MAX_REGISTER_VALUE + 1), num_lines
    )

    used_words = set()  # set membership keeps each uniqueness check O(1)
    parts = [PROMPT_HEADER]
    for rand_num in out_num_list:
        word_a = _draw_word(vocabulary, used_words)
        word_b = _draw_word(vocabulary, used_words)
        parts.append(f"line {word_a}-{word_b}: REGISTER_CONTENT is <{rand_num}>\n")

    prompt = "".join(parts)
    ground_num = random.choice(out_num_list)
    ground_pair = get_pair(ground_num, prompt)
    prompt += f"\nNow the record is over. Tell me what is the <REGISTER_CONTENT> in line {ground_pair}? I need the number.\n"
    return prompt, ground_num


total_data = []
avg_len = 0
for _ in tqdm(range(NUM_SAMPLES)):
    prompt, ground_num = _build_sample(line_num, word_list)

    token_len = len(tokenizer(prompt).input_ids)
    avg_len += token_len

    total_data.append({
        'expected_number': ground_num,
        'num_lines': line_num,
        'token_size': token_len,
        'prompt': prompt,
    })

print('AVG Len: ', avg_len / NUM_SAMPLES)
# Create the output directory up front so a missing 'src/' does not discard
# all of the generated samples at the final write.
os.makedirs('src', exist_ok=True)
with open(f'src/{line_num}.json', 'w', encoding='utf-8') as wp:
    json.dump(total_data, wp)


