import json
from collections import Counter, defaultdict
from pathlib import Path

from datasets import load_dataset
from fire import Fire
from IPython import embed
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Default dataset specification; count_unigram unpacks this tuple into the
# positional arguments of `load_dataset`.
WIKI_DATASET = ("wikipedia", "20220301.en")


def _write_json(path: Path, payload: dict) -> None:
    """Write *payload* to *path* as indented JSON, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def count_unigram(
    model_name: str = "sentence-transformers/average_word_embeddings_glove.840B.300d",
    dataset_name: "str | tuple[str, ...]" = WIKI_DATASET,
    output_dir: str = "data/wikipedia",
) -> None:
    """Count token frequencies over a dataset and save them as JSON.

    Every example's ``"text"`` field in every split of the dataset is run
    through the tokenizer of the given SentenceTransformer model, and the
    resulting tokens are tallied. Two files are written to
    ``<output_dir>/<last path component of model_name>/``:

    * ``raw_frequency.json`` -- token -> occurrence count, sorted by token.
    * ``unigram_prob.json`` -- token -> count divided by the total count.

    Args:
        model_name: Name or path passed to ``SentenceTransformer``.
        dataset_name: Positional arguments for ``load_dataset``. Either a
            single string, or a tuple/list that is unpacked into
            ``load_dataset``.
        output_dir: Root directory for the output files. The default keeps
            the original hard-coded location.
    """
    # A bare string must not be splatted: `load_dataset(*"wikipedia")` would
    # pass each character as a separate argument. This matters on the command
    # line, where Fire hands over a plain str for `--dataset_name=foo`.
    if isinstance(dataset_name, str):
        dataset_args = (dataset_name,)
    else:
        dataset_args = tuple(dataset_name)

    # might need dataset/model specific processing
    dataset = load_dataset(*dataset_args)
    model = SentenceTransformer(model_name)
    tokenizer = model.tokenizer
    # Clear the stop-word collection so those words are counted too
    # (assumes the tokenizer consults `stop_words` while tokenizing -- TODO confirm).
    tokenizer.stop_words = {}

    unigram_counter = Counter()
    for split, examples in dataset.items():
        for example in tqdm(examples, desc=str(split)):
            unigram_counter.update(tokenizer.tokenize(example["text"]))

    model_dir = Path(output_dir) / Path(model_name).name

    # Sort by token so the output file is deterministic. Note that JSON
    # serialization turns non-string keys into strings.
    raw_frequency = dict(sorted(unigram_counter.items()))
    _write_json(model_dir / "raw_frequency.json", raw_frequency)

    # Convert counts to probabilities. With an empty counter the
    # comprehension has nothing to iterate, so no division by zero occurs.
    total_count = sum(raw_frequency.values())
    unigram_prob = {k: v / total_count for k, v in raw_frequency.items()}
    _write_json(model_dir / "unigram_prob.json", unigram_prob)


if __name__ == "__main__":
    # Hand count_unigram to Fire so it can be driven from the command line.
    Fire(count_unigram)
