"""
Contains evaluation utilities for pytorch-based rewriting methods.
To use, simply call `compute_rewrite_quality_zsre` with the
appropriate arguments, which returns a dictionary containing them.
"""

import typing

from transformers import AutoTokenizer
from ..util import HyperParams
from .evaluate_utils import (
    test_prediction_acc,
    test_generation_quality,
    test_seq2seq_batch_prediction_acc,
    PPL,
    OOD_PPL,
    kl_loc_loss,
    F1
)


def compute_portability_quality(
    model,
    model_name,
    hparams: HyperParams,
    tok: AutoTokenizer,
    portability_key: str,
    prompt: str,
    ground_truth: str,
    device,
) -> typing.Dict:
    """Evaluate one portability probe and report it under ``<portability_key>_acc``.

    The scoring helper is chosen from ``model_name``: names containing
    ``'t5'`` (case-insensitive) go through
    ``test_seq2seq_batch_prediction_acc``; everything else goes through
    ``test_prediction_acc``, with ``vanilla_generation`` switched on when
    ``hparams.alg_name`` is ``'GRACE'`` or ``'Defer'``.

    Args:
        model: Model forwarded unchanged to the scoring helper.
        model_name: Name used only to pick the scoring helper.
        hparams: Hyper-parameters; ``alg_name`` is read on the non-T5 path.
        tok: Tokenizer forwarded unchanged to the scoring helper.
        portability_key: Prefix of the key in the returned dictionary.
        prompt: Prompt(s) to evaluate.
        ground_truth: Expected answer(s) for ``prompt``.
        device: Device forwarded unchanged to the scoring helper.

    Returns:
        A single-entry dict ``{f"{portability_key}_acc": <helper result>}``.
    """
    is_seq2seq = 't5' in model_name.lower()

    if is_seq2seq:
        score = test_seq2seq_batch_prediction_acc(
            model, tok, hparams, prompt, ground_truth, device
        )
    else:
        score = test_prediction_acc(
            model,
            tok,
            hparams,
            prompt,
            ground_truth,
            device,
            vanilla_generation=hparams.alg_name in ['GRACE', 'Defer'],
        )

    return {f"{portability_key}_acc": score}

def compute_rewrite_or_rephrase_quality(
    model,
    hparams: HyperParams,
    tok: AutoTokenizer,
    prompt: str,
    target_new: str,
    device,
    test_rephrase: bool = False,
    test_ood_rephrase: bool = False,
    eval_metric: str = 'token_em',
) -> typing.Dict:
    """Evaluate an edit on its rewrite, rephrase or OOD-rephrase prompt.

    The result key prefix is chosen from the flags, with
    ``test_ood_rephrase`` taking precedence over ``test_rephrase``:

    * ``test_ood_rephrase`` -> ``'ood_generality'`` (requires
      ``eval_metric == 'ppl'``),
    * ``test_rephrase``     -> ``'rephrase'``,
    * neither               -> ``'rewrite'``.

    Args:
        model: Model forwarded unchanged to the scoring helper.
        hparams: Hyper-parameters; ``model_name`` is always read,
            ``threshold`` on the ``'ppl'`` path and ``alg_name`` otherwise.
        tok: Tokenizer forwarded unchanged to the scoring helper.
        prompt: Prompt(s) to evaluate.
        target_new: Target answer(s) the edit should produce.
        device: Device forwarded unchanged to the scoring helper.
        test_rephrase: Report the result under the ``rephrase`` prefix.
        test_ood_rephrase: Report under ``ood_generality`` and score with
            ``OOD_PPL``.
        eval_metric: ``'ppl'`` selects the perplexity helpers; any other
            value selects prediction accuracy.

    Returns:
        A single-entry dict keyed ``<prefix>_threshold_succ`` (OOD
        rephrase), ``<prefix>_ppl`` (``eval_metric == 'ppl'``) or
        ``<prefix>_acc`` (otherwise).

    Raises:
        AssertionError: If ``test_ood_rephrase`` is set while
            ``eval_metric`` is not ``'ppl'``.
    """
    model_name = hparams.model_name

    if test_ood_rephrase:
        key = 'ood_generality'
        # Raised explicitly (rather than via `assert`) so the check survives
        # `python -O` and carries a real message. AssertionError is kept so
        # callers that already catch it see no change in exception type.
        if eval_metric != 'ppl':
            raise AssertionError(
                "OOD rephrase evaluation requires eval_metric='ppl', "
                f"got {eval_metric!r}"
            )
    elif test_rephrase:
        key = 'rephrase'
    else:
        key = 'rewrite'

    if eval_metric == 'ppl':
        func = OOD_PPL if test_ood_rephrase else PPL
        ppl = func(model, tok, prompt, target_new, device, threshold=hparams.threshold)
        suffix = 'threshold_succ' if test_ood_rephrase else 'ppl'
        return {f"{key}_{suffix}": ppl}

    if 't5' in model_name.lower():
        acc = test_seq2seq_batch_prediction_acc(model, tok, hparams, prompt, target_new, device)
    else:
        # Only GRACE requests vanilla generation here; for every other
        # algorithm the keyword is omitted so the helper's default applies.
        # NOTE(review): the locality/portability evaluators in this module
        # also enable it for 'Defer' -- confirm whether that is intended here.
        extra_kwargs = {'vanilla_generation': True} if hparams.alg_name in ["GRACE"] else {}
        acc = test_prediction_acc(model, tok, hparams, prompt, target_new, device, **extra_kwargs)

    return {f"{key}_acc": acc}

def compute_locality_quality(
    model,
    hparams: HyperParams,
    tok: AutoTokenizer,
    locality_key: str,
    prompt: str,
    locality_ground_truth: str,
    device,
) -> typing.Dict:
    """Run a locality probe and return the helper output as a list.

    The scoring helper is chosen from ``hparams.model_name``: names
    containing ``'t5'`` (case-insensitive) use
    ``test_seq2seq_batch_prediction_acc``; everything else uses
    ``test_prediction_acc`` with ``vanilla_generation`` enabled when
    ``hparams.alg_name`` is ``'GRACE'`` or ``'Defer'``. Both helpers are
    called with ``locality=True``.

    Args:
        model: Model forwarded unchanged to the scoring helper.
        hparams: Hyper-parameters; ``model_name`` is always read and
            ``alg_name`` on the non-T5 path.
        tok: Tokenizer forwarded unchanged to the scoring helper.
        locality_key: Prefix of the key in the returned dictionary.
        prompt: Prompt(s) to evaluate.
        locality_ground_truth: Reference answer(s) forwarded to the helper.
        device: Device forwarded unchanged to the scoring helper.

    Returns:
        A single-entry dict ``{f"{locality_key}_output": <list>}``. A helper
        result that is not already a list is wrapped in a one-element list.
    """
    model_name = hparams.model_name
    if 't5' in model_name.lower():
        loc_tokens = test_seq2seq_batch_prediction_acc(
            model, tok, hparams, prompt, locality_ground_truth, device, locality=True
        )
    else:
        loc_tokens = test_prediction_acc(
            model,
            tok,
            hparams,
            prompt,
            locality_ground_truth,
            device,
            locality=True,
            vanilla_generation=hparams.alg_name in ['GRACE', 'Defer'],
        )

    # Normalise to a list so callers always get the same container type.
    # isinstance (rather than an exact type() comparison) keeps list
    # subclasses from being wrapped a second time.
    if not isinstance(loc_tokens, list):
        loc_tokens = [loc_tokens]

    return {f"{locality_key}_output": loc_tokens}