from data import CausalDataset, MoralDataset, Example, Annotation, Sentence, AbstractDataset, JsonSerializable, FactorUtils
from adapter import GPT3Adapter, HuggingfaceAdapter, DelphiAdapter
from evaluator import AccuracyEvaluator, AccuracyEvaluatorWithAmbiguity, CorrelationEvaluator, RMSEEvaluator, AuROCEvaluator
from prompt import JudgmentPrompt, AbstractJudgmentPrompt, CausalJudgmentPrompt, CausalAbstractJudgmentPrompt
from thought_as_text_translator import MoralTranslator, CausalTranslator, Translator
from expert_reasoning_engine import CausalReasoningEngine, MoralReasoningEngine
from tqdm import tqdm
import pickle

import json
from typing import List

import numpy as np

from dataclasses import dataclass

@dataclass
class ExperimentResult(JsonSerializable):
    """Evaluation metrics for one Exp 4 run (one task, one engine).

    Built by the ``exp4_*_pred`` functions from the outputs of the accuracy,
    correlation, RMSE and AuROC evaluators. ``produce_table5`` serializes it via
    its ``.json`` attribute, which presumably comes from ``JsonSerializable``
    (not defined in this file -- confirm there).
    """
    acc: float  # accuracy returned by AccuracyEvaluatorWithAmbiguity.evaluate
    conf_interval: tuple[float, float]  # (low, high) interval returned alongside acc
    r: float  # correlation value returned by CorrelationEvaluator.evaluate
    p: float  # second value from CorrelationEvaluator.evaluate; printed as "p=" (presumably the p-value)
    rmse: float  # value returned by RMSEEvaluator.evaluate
    auroc: float  # value returned by AuROCEvaluator.evaluate


def override_examples_in_dataset(engine: str, cd: AbstractDataset, task: str = 'causal'):
    """Replace every example's factor annotations with Exp 3 predictions, in place.

    Loads ``../../results/factor_preds/exp3_{engine}_{task}_factor_preds.pkl``,
    which is unpacked as ``(instances, choice_scores, label_indices)``. The code
    assumes these are flat sequences with one entry per annotated sentence, in
    the same order as iterating ``cd`` and then each example's
    ``annotated_sentences``.

    For every sentence, the highest-scoring choice of the matching instance is
    taken as the predicted tag. The tag is lower-cased and mapped back to a
    factor value through ``FactorUtils.<factor>_answers_map_reverse``. A new
    ``Sentence`` with the same text and victim but the predicted ``Annotation``
    replaces the original.

    :param engine: engine name used to locate the Exp 3 prediction file.
    :param cd: dataset whose examples are modified in place; nothing is returned.
    :param task: task name used to locate the prediction file (e.g. 'causal', 'moral').
    :raises AssertionError: if an example ends up with a different number of
        sentences than it started with.
    """
    path = f"../../results/factor_preds/exp3_{engine}_{task}_factor_preds.pkl"
    # NOTE: pickle.load can execute arbitrary code; only use this on result
    # files produced locally by Exp 3.
    with open(path, "rb") as f:
        all_instances, all_choice_scores, _all_label_indices = pickle.load(f)

    # Index into the flat prediction lists; advances once per sentence.
    counter = 0
    ex: Example
    for ex in tqdm(cd, total=len(cd)):
        original_num_sents = len(ex.annotated_sentences)
        new_annotations: List[Sentence] = []
        sent: Sentence
        for sent in ex.annotated_sentences:
            factor = sent.annotation.factor
            # Look the reverse answer map up by attribute name rather than eval().
            answers_map_reverse = getattr(FactorUtils, f"{factor}_answers_map_reverse")

            # Build a new Sentence carrying the predicted (argmax) tag.
            choice_idx = np.argmax(all_choice_scores[counter])
            tag = all_instances[counter].choices[choice_idx]
            annotation = Annotation(factor, answers_map_reverse[tag.lower()])

            new_annotations.append(Sentence(sent.text, sent.victim, annotation))
            counter += 1

        ex.annotated_sentences = new_annotations

        assert original_num_sents == len(new_annotations)

def run_template_for_gpt3(cd: AbstractDataset, adapter: GPT3Adapter,
                          jp: JudgmentPrompt, ajp: AbstractJudgmentPrompt,
                          translator: Translator,
                          method: str = 'yesno'):
    """Score every example in ``cd`` with one prompt template through ``adapter``.

    Examples with no annotated sentences are rendered with the story prompt
    ``jp``. All other examples are first converted by
    ``translator.translate_example`` and then rendered with the abstract
    prompt ``ajp``.

    :param cd: dataset to iterate.
    :param adapter: model adapter; ``adapter.adapt(instance, method=method)`` produces the scores.
    :param jp: prompt applied to the raw example when it has no annotations.
    :param ajp: prompt applied to the translated (abstract) example.
    :param translator: converts an annotated example into its abstract form.
    :param method: scoring method forwarded to ``adapter.adapt``.
    :return: ``(all_choice_scores, all_label_dist)``, two parallel lists with one
        entry per example: the adapter output and the example's ``answer_dist``.
    """
    all_choice_scores, all_label_dist = [], []
    ex: Example
    for ex in tqdm(cd):
        if len(ex.annotated_sentences) == 0:
            instance = jp.apply(ex)
        else:
            # Named so it does not shadow the builtin abs().
            abstract_example = translator.translate_example(ex)
            instance = ajp.apply(abstract_example)

        choice_scores = adapter.adapt(instance, method=method)
        all_choice_scores.append(choice_scores)
        all_label_dist.append(ex.answer_dist)

    return all_choice_scores, all_label_dist

def exp4_causal_ttt_pred(engine='text-davinci-002'):
    """Exp 4 (thought-as-text): causal judgments from Exp 3 predicted factors.

    Steps:
      1. Load the factor predictions Exp 3 produced for ``engine``.
      2. Overwrite each example's annotations with those predicted factors.
      3. Re-run the Exp 2 abstract causal-judgment prompts: a yes/no template
         and a multiple-choice template. Both runs are pooled into one set of
         scores and labels.

    Prints accuracy (with interval), correlation, RMSE and AuROC over the
    pooled scores and returns them as an ``ExperimentResult``.
    """
    dataset = CausalDataset()
    acc_eval = AccuracyEvaluatorWithAmbiguity()
    corr_eval = CorrelationEvaluator()
    rmse_eval = RMSEEvaluator()
    auroc_eval = AuROCEvaluator()

    adapter = GPT3Adapter(engine=engine)

    override_examples_in_dataset(engine, dataset, task='causal')

    # (template file, scoring method); results of both runs are concatenated.
    runs = (
        ("./prompts/exp1_causal_prompt.jinja", 'yesno'),
        ("./prompts/exp1_causal_prompt_2.jinja", 'multiple_choice'),
    )
    pooled_scores, pooled_labels = [], []
    for template, scoring in runs:
        scores, labels = run_template_for_gpt3(
            dataset,
            adapter,
            CausalJudgmentPrompt(template),
            CausalAbstractJudgmentPrompt(template),
            translator=CausalTranslator(),
            method=scoring,
        )
        pooled_scores += scores
        pooled_labels += labels

    acc, conf_interval = acc_eval.evaluate(pooled_scores, pooled_labels)
    r, p = corr_eval.evaluate(pooled_scores, pooled_labels)
    rmse = rmse_eval.evaluate(pooled_scores, pooled_labels)
    auroc = auroc_eval.evaluate(pooled_scores, pooled_labels)

    print()
    print(f"engine: {engine}")
    print(f"Causal Abstract Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")
    print(f"Causal Correlation: {r:.4f} (p={p:.4f})")
    print(f"Causal RMSE: {rmse:.4f}")
    print(f"Causal AuROC: {auroc:.4f}")

    return ExperimentResult(acc=acc, conf_interval=conf_interval, r=r, p=p,
                            rmse=rmse, auroc=auroc)

def exp4_moral_ttt_pred(engine='text-davinci-002'):
    """Exp 4 (thought-as-text): moral judgments from Exp 3 predicted factors.

    Steps:
      1. Load the factor predictions Exp 3 produced for ``engine``.
      2. Overwrite each example's annotations with those predicted factors.
      3. Re-run the Exp 2 abstract moral-judgment prompts: a yes/no template
         and a multiple-choice template. Both runs are pooled into one set of
         scores and labels.

    Prints accuracy (with interval), correlation, RMSE and AuROC over the
    pooled scores and returns them as an ``ExperimentResult``.
    """
    dataset = MoralDataset()
    acc_eval = AccuracyEvaluatorWithAmbiguity()
    corr_eval = CorrelationEvaluator()
    rmse_eval = RMSEEvaluator()
    auroc_eval = AuROCEvaluator()

    adapter = GPT3Adapter(engine=engine)

    override_examples_in_dataset(engine, dataset, task='moral')

    # (template file, scoring method); results of both runs are concatenated.
    runs = (
        ("./prompts/exp1_moral_prompt.jinja", 'yesno'),
        ("./prompts/exp1_moral_prompt_2.jinja", 'multiple_choice'),
    )
    pooled_scores, pooled_labels = [], []
    for template, scoring in runs:
        scores, labels = run_template_for_gpt3(
            dataset,
            adapter,
            JudgmentPrompt(template),
            AbstractJudgmentPrompt(template),
            translator=MoralTranslator(),
            method=scoring,
        )
        pooled_scores += scores
        pooled_labels += labels

    acc, conf_interval = acc_eval.evaluate(pooled_scores, pooled_labels)
    r, p = corr_eval.evaluate(pooled_scores, pooled_labels)
    rmse = rmse_eval.evaluate(pooled_scores, pooled_labels)
    auroc = auroc_eval.evaluate(pooled_scores, pooled_labels)

    print()
    print(f"engine: {engine}")
    print(f"Moral Abstract Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")
    print(f"Moral Correlation: {r:.4f} (p={p:.4f})")
    print(f"Moral RMSE: {rmse:.4f}")
    print(f"Moral AuROC: {auroc:.4f}")

    return ExperimentResult(acc=acc, conf_interval=conf_interval, r=r, p=p,
                            rmse=rmse, auroc=auroc)

def exp4_causal_ere_pred(engine='text-davinci-002'):
    """Exp 4 (expert reasoning engine): causal judgments from Exp 3 predicted factors.

    A ``CausalReasoningEngine`` is built on the dataset and trained *before*
    the annotations are overwritten, so ``train()`` still sees the dataset's
    original annotations (presumably the gold ones; confirm in the engine).
    Each example's annotations are then replaced with the factors Exp 3
    predicted for ``engine``, and the trained engine scores every example.

    Prints accuracy (with interval), correlation, RMSE and AuROC and returns
    them as an ``ExperimentResult``.
    """
    dataset = CausalDataset()
    acc_eval = AccuracyEvaluatorWithAmbiguity()
    corr_eval = CorrelationEvaluator()
    rmse_eval = RMSEEvaluator()
    auroc_eval = AuROCEvaluator()

    reasoner = CausalReasoningEngine(dataset)
    reasoner.train()

    override_examples_in_dataset(engine, dataset, task='causal')

    # One pass: predict, then read the label, for each example in order.
    scored = [(reasoner.predict(example), example.answer_dist) for example in dataset]
    predictions = [scores for scores, _ in scored]
    labels = [dist for _, dist in scored]

    acc, conf_interval = acc_eval.evaluate(predictions, labels)
    r, p = corr_eval.evaluate(predictions, labels)
    rmse = rmse_eval.evaluate(predictions, labels)
    auroc = auroc_eval.evaluate(predictions, labels)

    print()
    print(f"engine: {engine}")
    print(f"Causal Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")
    print(f"Causal Correlation: {r:.4f} (p={p:.4f})")
    print(f"Causal RMSE: {rmse:.4f}")
    print(f"Causal AuROC: {auroc:.4f}")

    return ExperimentResult(acc=acc, conf_interval=conf_interval, r=r, p=p,
                            rmse=rmse, auroc=auroc)

def exp4_moral_ere_pred(engine='text-davinci-002'):
    """Exp 4 (expert reasoning engine): moral judgments from Exp 3 predicted factors.

    A ``MoralReasoningEngine`` is built on the dataset and trained *before*
    the annotations are overwritten, so ``train()`` still sees the dataset's
    original annotations (presumably the gold ones; confirm in the engine).
    Each example's annotations are then replaced with the factors Exp 3
    predicted for ``engine``, and the trained engine scores every example.

    Prints accuracy (with interval), correlation, RMSE and AuROC and returns
    them as an ``ExperimentResult``.
    """
    dataset = MoralDataset()
    acc_eval = AccuracyEvaluatorWithAmbiguity()
    corr_eval = CorrelationEvaluator()
    rmse_eval = RMSEEvaluator()
    auroc_eval = AuROCEvaluator()

    reasoner = MoralReasoningEngine(dataset)
    reasoner.train()

    override_examples_in_dataset(engine, dataset, task='moral')

    # One pass: predict, then read the label, for each example in order.
    scored = [(reasoner.predict(example), example.answer_dist) for example in dataset]
    predictions = [scores for scores, _ in scored]
    labels = [dist for _, dist in scored]

    acc, conf_interval = acc_eval.evaluate(predictions, labels)
    r, p = corr_eval.evaluate(predictions, labels)
    rmse = rmse_eval.evaluate(predictions, labels)
    auroc = auroc_eval.evaluate(predictions, labels)

    print()
    print(f"engine: {engine}")
    print(f"Moral Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")
    print(f"Moral Correlation: {r:.4f} (p={p:.4f})")
    print(f"Moral RMSE: {rmse:.4f}")
    print(f"Moral AuROC: {auroc:.4f}")

    return ExperimentResult(acc=acc, conf_interval=conf_interval, r=r, p=p,
                            rmse=rmse, auroc=auroc)

def produce_table5(engines=("text-babbage-001", 'text-curie-001', 'text-davinci-002'),
                   results_dir='../../results'):
    """Run every Exp 4 variant for each engine and write one JSON file per variant.

    Variants run in this fixed order, each over all of ``engines``. A variant's
    file is written as soon as that variant finishes, so files from earlier
    variants are kept if a later one fails:

    * ``exp4_causal_ere_pred`` -> ``{results_dir}/exp4_causal_ere_result.json``
    * ``exp4_moral_ere_pred``  -> ``{results_dir}/exp4_moral_ere_result.json``
    * ``exp4_causal_ttt_pred`` -> ``{results_dir}/exp4_causal_ttt_result.json``
    * ``exp4_moral_ttt_pred``  -> ``{results_dir}/exp4_moral_ttt_result.json``

    Each file holds a JSON object mapping engine name to that run's
    ``ExperimentResult(...).json`` value.

    :param engines: engine names to evaluate. Defaults to the three engines
        this table has always used. Any iterable is accepted; it is
        materialized once so every variant sees the same engines.
    :param results_dir: existing directory the JSON files are written into.
    """
    # Materialize so a one-shot iterator is not exhausted by the first variant.
    engines = tuple(engines)

    variants = (
        ('causal_ere', exp4_causal_ere_pred),
        ('moral_ere', exp4_moral_ere_pred),
        ('causal_ttt', exp4_causal_ttt_pred),
        ('moral_ttt', exp4_moral_ttt_pred),
    )
    for variant, run_variant in variants:
        results = {engine: run_variant(engine).json for engine in engines}

        # json save
        with open(f'{results_dir}/exp4_{variant}_result.json', 'w') as f:
            json.dump(results, f, indent=2)

if __name__ == '__main__':
    # Regenerate every Exp 4 result file. To run a single experiment/engine
    # instead, call e.g. exp4_causal_ere_pred(engine='text-curie-001').
    produce_table5()