import logging
from typing import Dict, Optional

import transformers

from rtfm import special_tokens as tok


def sanity_check_tokenizer(tokenizer, model_name):
    """Check that each special token encodes to exactly one expected token id.

    The expected ids are chosen from ``model_name`` and the tokenizer length:

    * name contains "llama" and "2", and ``len(tokenizer) < 128_254``;
    * name contains "llama" and "3", and ``len(tokenizer) > 128254``.

    The checks are raised explicitly (rather than written as ``assert``
    statements) so they still run when Python is started with ``-O``.

    :param tokenizer: the tokenizer to check. It must support ``len()`` and be
        callable as ``tokenizer(text, add_special_tokens=False)``, returning a
        mapping with an ``"input_ids"`` entry.
    :param model_name: model identifier used to select the expected ids.
    :raises ValueError: if neither model-name/length combination matches.
    :raises AssertionError: if a special token does not encode to exactly the
        single expected id.
    """
    logging.warning("sanity checking the tokenizer special tokens are in vocab...")
    name_lower = model_name.lower()

    # The length bound disambiguates names that contain both digits
    # (a "2" and a "3" can both appear in one model name).
    if "llama" in name_lower and "2" in model_name and len(tokenizer) < 128_254:
        eoc_token_id_expected = 32000
        qa_token_id_expected = 32001
        # NOTE(review): presumably this token already exists in the base vocab
        # for this branch (the original note said "llama3", likely a
        # copy-paste) -- confirm.
        choices_sep_token_expected = 8876
    elif "llama" in name_lower and "3" in model_name and len(tokenizer) > 128254:
        eoc_token_id_expected = 128256
        qa_token_id_expected = 128257
        choices_sep_token_expected = 8651  # this token is already in llama3 vocab
    else:
        raise ValueError(f"unknown model name: {model_name}")

    # (label used in the error message, token text, expected single id)
    checks = (
        ("EOC", tok.EOC_TOKEN, eoc_token_id_expected),
        ("QA_SEP", tok.QA_SEP_TOKEN, qa_token_id_expected),
        (
            "ANS_CHOICES_SEP_TOKEN",
            tok.ANS_CHOICES_SEP_TOKEN,
            choices_sep_token_expected,
        ),
    )
    for label, token, expected_id in checks:
        # Tokenize once and reuse the result in the failure message.
        token_ids = tokenizer(token, add_special_tokens=False)["input_ids"]
        if token_ids != [expected_id]:
            raise AssertionError(f"{label} token tokenizes to {token_ids}")

    logging.warning("tokenizer sanity check passed!")


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict[str, str],
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    other_tokens_dict: Optional[Dict[str, str]] = None,
) -> None:
    """Resize tokenizer and embedding matrix, adding both special_tokens_dict and other_tokens_dict.

    After the tokens are added, the model's token embeddings are resized to
    ``len(tokenizer)``. If any tokens were actually added, the trailing
    ``num_new_tokens`` rows of the input embeddings (and of the output
    embeddings, when the model has them) are set to the mean of all preceding
    rows.

    :param special_tokens_dict: special tokens that can be added with tokenizer.add_special_tokens().
        Typically this only includes tokens like bos_token, eos_token, pad_token.
        See transformers.tokenization_utils method .add_special_tokens() for more info.
    :param tokenizer: the tokenizer to modify (in place).
    :param model: the model to be used with the tokenizer; its embedding matrix will be
        resized accordingly (in place).
    :param other_tokens_dict: tokens that cannot be added with tokenizer.add_special_tokens().
        This is where most tokens should be added. Only the dict *values* are
        added, via tokenizer.add_tokens(..., special_tokens=True).
    :return: None; ``tokenizer`` and ``model`` are modified in place.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    logging.debug(f"len(tokenizer) before resize is {len(tokenizer)}")
    logging.warning(f"adding special tokens {special_tokens_dict} to vocab")
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    if other_tokens_dict:
        logging.warning(f"adding tokens {other_tokens_dict} to vocab")
        num_new_tokens += tokenizer.add_tokens(
            list(other_tokens_dict.values()), special_tokens=True
        )
    logging.info(f"adding {num_new_tokens} to vocab")
    # Resize unconditionally so the embedding size matches the tokenizer even
    # when no token turned out to be new.
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens <= 0:
        return

    input_embeddings = model.get_input_embeddings().weight.data
    # get_output_embeddings() may return None for models without an output
    # head; in that case only the input embeddings are initialised.
    output_layer = model.get_output_embeddings()
    output_embeddings = None if output_layer is None else output_layer.weight.data

    # Compute every mean before writing anything, so that when the input and
    # output weights share storage both means see the same pre-update rows.
    input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
        dim=0, keepdim=True
    )
    output_embeddings_avg = None
    if output_embeddings is not None:
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True
        )

    # The (1, dim) mean row broadcasts across all newly added rows.
    input_embeddings[-num_new_tokens:] = input_embeddings_avg
    if output_embeddings is not None:
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
