"""
Perform TF-IDF on Alpaca+ training set.
"""

import json

import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer

# Recorded in place of a word when an instruction has no usable word, i.e. it
# is empty/whitespace-only ("first_word") or contains no term from the fitted
# vocabulary ("tf_idf_word").
NO_WORD = ""

TRAIN_FILE = "alpaca_plus/alpaca_plus_train.json"
# A tuple rather than a set so the files are always processed in this order.
VALIDATION_SPLITS = ("human", "seen", "unseen")


def _first_word(text):
    """Return the first whitespace-delimited token of *text*.

    Returns NO_WORD when *text* is empty or whitespace-only, instead of
    raising IndexError.
    """
    tokens = text.split(maxsplit=1)
    return tokens[0] if tokens else NO_WORD


def _best_tfidf_words(scores, feature_names):
    """Return the highest-scoring term of every row of *scores*.

    Args:
        scores: documents x vocabulary TF-IDF matrix as returned by the
            vectorizer (assumed to be a scipy sparse matrix, which is what
            TfidfVectorizer produces).
        feature_names: vocabulary terms, indexed by column of *scores*.

    Returns:
        A list with one entry per row: the term whose column holds the row
        maximum, or NO_WORD when the row has no positive score. Without that
        check an all-zero row would yield column 0, i.e. an arbitrary
        vocabulary term that does not occur in the instruction.
    """
    # Work on the sparse matrix directly: densifying it allocates
    # documents x vocabulary floats only to take a row-wise argmax.
    # Like the dense argmax, scipy's sparse argmax returns the first
    # (lowest-column) occurrence when several entries tie for the maximum.
    best_cols = scores.argmax(axis=1)
    best_vals = scores.max(axis=1).toarray()
    # `.flat` iterates scalars whether scipy returned 1-D arrays or (n, 1)
    # matrices.
    return [
        feature_names[col] if val > 0 else NO_WORD
        for col, val in zip(best_cols.flat, best_vals.flat)
    ]


def _annotate_file(path, vectorizer, fit=False):
    """Add "tf_idf_word" and "first_word" to every example in *path*.

    The file is read as a JSON list of dicts that each have an "instruction"
    string, and is overwritten in place with the annotated list.

    Args:
        path: JSON file to annotate.
        vectorizer: the TfidfVectorizer to score instructions with.
        fit: when true, fit *vectorizer* on this file's instructions before
            scoring them (training set); otherwise reuse the fitted
            vocabulary and IDF weights (validation sets).
    """
    with open(path, encoding="utf-8") as f:
        examples = json.load(f)

    instructions = [example["instruction"] for example in examples]
    if fit:
        scores = vectorizer.fit_transform(instructions)
    else:
        scores = vectorizer.transform(instructions)

    best_words = _best_tfidf_words(scores, vectorizer.get_feature_names_out())
    for example, best_word in zip(examples, best_words):
        example["tf_idf_word"] = best_word
        example["first_word"] = _first_word(example["instruction"])

    with open(path, "w", encoding="utf-8") as f:
        json.dump(examples, f)


vec = TfidfVectorizer(stop_words='english')

# Fit on the training set and record its words.
_annotate_file(TRAIN_FILE, vec, fit=True)

# Do the same for the validation sets, reusing the training-set fit.
for split in VALIDATION_SPLITS:
    _annotate_file(f"alpaca_plus/alpaca_plus_validation_{split}.json", vec)