import os
import argparse
import openai
import math
from tqdm import tqdm

from utils import *
from dataset_utils import read_synth_data, index_example, reorder_rationale
from joint import prompt_for_joint_prediction
import numpy as np
import matplotlib.pyplot as plt

# Take the OpenAI credential from the environment at import time;
# os.getenv returns None when OPENAI_API_KEY is not set.
openai.api_key = os.getenv("OPENAI_API_KEY")

from sklearn.metrics import roc_curve, auc, roc_auc_score

def _parse_args():
    """Parse the command line and return the resulting namespace.

    Engine options are registered by ``add_engine_argumenet`` before the
    script-specific options, and the parsed namespace is handed to
    ``specify_engine`` before being returned (both helpers are expected to
    come from the ``utils`` star import).
    """
    # Script-specific options, registered in this order.
    option_table = (
        ('--style', dict(type=str, default="p-e", choices=["p-e", "p-e-r"])),
        ('--run_prediction', dict(default=False, action='store_true')),
        ('--num_shot', dict(type=int, default=16)),
        ('--train_slice', dict(type=int, default=0)),
        ('--num_dev', dict(type=int, default=250)),
        ('--strategy', dict(type=str, default="random", choices=["random"])),
        ('--reorder', dict(default=False, action='store_true')),
    )

    cli = argparse.ArgumentParser()
    add_engine_argumenet(cli)
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    parsed = cli.parse_args()
    specify_engine(parsed)
    return parsed

def result_cache_name(args):
    """Return the path of the JSON file that caches this run's predictions.

    The file name encodes ``args.engine_name``, the training slice bounds
    (``train_slice`` and ``train_slice + num_shot``), ``args.num_dev`` and
    ``args.style``.
    """
    slice_start = args.train_slice
    slice_end = slice_start + args.num_shot
    template = "misc/score_joint_{}_tr{}-{}_dv{}_{}_predictions.json"
    return template.format(args.engine_name, slice_start, slice_end, args.num_dev, args.style)

def random_sample_strategy(ex, training_data, num_shot):
    """Pick the demonstrations to show for ``ex``.

    Despite the name nothing is sampled: the first ``num_shot`` entries of
    ``training_data`` are returned and ``ex`` is not consulted.
    """
    selected_shots = training_data[0:num_shot]
    return selected_shots

def get_candidate_answers(ex):
    """Return the first word of every clue in ``ex["context"]``.

    The context is stripped of leading/trailing periods and split on commas;
    each resulting clue contributes its first whitespace-separated word.
    Clues that contain no word at all are skipped.

    NOTE(review): this keeps the comma separator the code already used;
    confirm that it matches how clues are delimited in the dataset.

    Args:
        ex: example mapping with a string ``"context"`` entry.

    Returns:
        List of candidate answer strings, in clue order.
    """
    context = ex["context"]
    clues = context.strip(".").split(",")

    candidates = []
    # Iterate over the clues themselves. Looping over ``clues[0]`` walked the
    # characters of the first clue and crashed on the first space.
    for clue in clues:
        words = clue.split()
        if words:
            candidates.append(words[0])
    return candidates

def parse_answer_and_rationale(text, style):
    """Split a completion of the form ``"<answer>, because <rationale>"``.

    Both supported styles (``"p-e"`` and ``"p-e-r"``) share the same
    ``', because '`` separator. Text that does not split into exactly two
    fields is reported on stdout as an outlier:

    * no separator: the whole text is the answer and the rationale is the
      ``"null"`` placeholder (this used to raise IndexError);
    * more than one separator: the first two fields are used.

    Args:
        text: raw completion text; surrounding whitespace is ignored.
        style: decoding style, ``"p-e"`` or ``"p-e-r"``.

    Returns:
        ``(answer, rationale)`` tuple of strings.

    Raises:
        RuntimeError: if ``style`` is not a supported decoding style.
    """
    text = text.strip()

    if style not in ("p-e", "p-e-r"):
        raise RuntimeError("Unsupported decoding style")

    fields = text.split(', because ')
    if len(fields) != 2:
        print("outlier", fields)

    # str.split always yields at least one field, so fields[0] is safe.
    answer = fields[0]
    rationale = fields[1] if len(fields) > 1 else "null"
    return answer, rationale

def _echoed_completion(engine, prompt, stop_signal, base_len):
    """Request a greedy completion with echo on and drop the first ``base_len`` characters.

    The returned choice has ``"text"`` cut at ``base_len`` and records that
    cut position under ``"completion_offset"``.
    """
    resp = openai.Completion.create(engine=engine, prompt=prompt, temperature=0.0, max_tokens=48, echo=True, logprobs=5, stop=stop_signal)
    pred = resp["choices"][0]
    pred["text"] = pred["text"][base_len:]
    pred["completion_offset"] = base_len
    return pred


def in_context_candidates_scoring(ex, training_data, engine, style="standard", strategy="random", num_shot=10):
    """Score the free ("natural") prediction and every candidate answer for ``ex``.

    One completion is requested for the base prompt. Then, for each candidate
    from ``get_candidate_answers``, another completion is requested with the
    prompt extended by ``"A: <candidate>\\n"``. Candidate texts are cut at the
    length of the *base* prompt, so they still start with the forced answer.

    The candidate equal to the parsed natural answer reuses the natural
    prediction instead of triggering another request. The original compared
    the candidate string with the whole prediction dict, which never matched.

    Args:
        ex: example to predict on.
        training_data: pool of demonstrations.
        engine: engine name passed to the completion API.
        style: decoding style; only ``"p-e"`` and ``"p-e-r"`` are usable here.
        strategy: shot selection strategy; only ``"random"`` is supported.
        num_shot: number of demonstrations to put in the prompt.

    Returns:
        Dict with ``"natural"`` (the free prediction), ``"base_prompt"`` and
        ``"candidate_scores"`` (candidate string -> prediction).

    Raises:
        RuntimeError: on an unsupported strategy or decoding style.
    """
    if strategy == "random":
        shots = random_sample_strategy(ex, training_data, num_shot)
    else:
        raise RuntimeError("Unsupported shot selection strategy")

    prompt, stop_signal = prompt_for_joint_prediction(ex, shots, style)
    base_len = len(prompt)

    natural_pred = _echoed_completion(engine, prompt, stop_signal, base_len)
    natural_ans, _ = parse_answer_and_rationale(natural_pred["text"], style)

    candidate_preds = {}
    for candidate in get_candidate_answers(ex):
        if candidate == natural_ans:
            # Same answer as the free prediction: reuse it, no extra request.
            candidate_preds[candidate] = natural_pred
            continue
        # expand prompt
        if style == "p-e" or style == "p-e-r":
            can_prompt = prompt + "A: " + candidate + "\n"
        else:
            raise RuntimeError("Unsupported decoding style")
        candidate_preds[candidate] = _echoed_completion(engine, can_prompt, stop_signal, base_len)

    return {
        "natural": natural_pred,
        "base_prompt": prompt,
        "candidate_scores": candidate_preds,
    }

def conditional_strip_prompt_prefix(x, p):
    """Remove a leading ``p`` from ``x`` if present, then trim surrounding whitespace."""
    remainder = x[len(p):] if x.startswith(p) else x
    return remainder.strip()

def test_few_shot_candidates_scoring(args):
    """Run candidate scoring over the dev slice and cache the results as JSON.

    Training examples ``[train_slice, train_slice + num_shot)`` serve as the
    demonstration pool and the first ``num_dev`` dev examples are scored.
    Both sets are passed through ``index_example`` first. The list of
    per-example results is written to ``result_cache_name(args)``.
    """
    print("Running prediction")
    train_raw = read_synth_data("data/100-train_synth.json")
    dev_raw = read_synth_data("data/250-dev_synth.json")

    shot_end = args.train_slice + args.num_shot
    train_set = [index_example(item) for item in train_raw[args.train_slice:shot_end]]
    dev_set = [index_example(item) for item in dev_raw[:args.num_dev]]

    predictions = [
        in_context_candidates_scoring(
            dev_ex,
            train_set,
            engine=args.engine,
            style=args.style,
            strategy=args.strategy,
            num_shot=args.num_shot,
        )
        for dev_ex in tqdm(dev_set, total=len(dev_set), desc="Predicting")
    ]

    # save
    dump_json(predictions, result_cache_name(args))
def evaluate_rationale_match(gt, p):
    """Tell whether predicted rationale ``p`` matches gold rationale ``gt``.

    A match is either exact equality, or equality after swapping the two
    parts of ``p``: the part up to and including its first period, and
    whatever follows that period plus one more character.
    """
    if gt == p:
        return True

    head, dot, _ = p.partition('.')
    if not dot:
        # No period, so there is nothing to swap.
        return False

    first_part = head + dot
    # Skip the character right after the period as well.
    second_part = p[len(first_part) + 1:]
    swapped = second_part + ' ' + first_part
    return swapped == gt


def process_joint_prediction(p, style):
    """Parse ``p["text"]`` and store the result on ``p`` in place.

    The stripped text is split on ``', because '`` (the same separator for
    both supported styles). ``p["answer"]`` gets the first field and
    ``p["rationale"]`` the second. Text that does not split into exactly two
    fields is reported on stdout as an outlier; when the separator is missing
    the rationale falls back to the ``"null"`` placeholder instead of raising
    IndexError.

    Args:
        p: prediction dict with a ``"text"`` entry; mutated in place.
        style: decoding style, ``"p-e"`` or ``"p-e-r"``.

    Returns:
        The same dict ``p``, now carrying ``"answer"`` and ``"rationale"``.

    Raises:
        RuntimeError: if ``style`` is not a supported decoding style.
    """
    text = p["text"].strip()

    if style not in ("p-e", "p-e-r"):
        raise RuntimeError("Unsupported decoding style")

    fields = text.split(', because ')
    if len(fields) != 2:
        print("outlier", fields)

    # str.split always yields at least one field, so fields[0] is safe.
    p["answer"] = fields[0]
    p["rationale"] = fields[1] if len(fields) > 1 else "null"
    return p

# get answer_conf, explanation conf
def calc_completion_confidence(pred, style):
    """Return summed log-probabilities of the answer span and the rationale span.

    The completion runs from the token whose text offset equals
    ``pred["completion_offset"]`` up to (not including) the first
    ``"<|endoftext|>"`` token. Within it, the first ``"\\n"`` token is the
    divider: the answer span is ``[2, divider)`` and the rationale span
    starts five tokens after the divider.

    NOTE(review): the 2 and 5 offsets are fixed token counts, presumably
    skipping the answer/rationale label tokens. Confirm them against the
    prompt layout of each style.

    ``"p-e"`` and ``"p-e-r"`` are accepted in addition to the older style
    names, because ``coverage_test`` only runs with those two styles and
    forwards them here; rejecting them made every scored branch fail.

    Args:
        pred: completion choice with ``"completion_offset"`` and a
            ``"logprobs"`` mapping holding ``"tokens"``, ``"text_offset"``
            and ``"token_logprobs"``.
        style: decoding style name.

    Returns:
        ``(ans_log_probs, rat_log_probs)``.

    Raises:
        RuntimeError: if ``style`` is not a recognised decoding style.
        ValueError: if the completion offset, the ``"<|endoftext|>"`` token
            or the ``"\\n"`` divider cannot be found.
    """
    if style not in ("standard", "insta", "instb", "p-e", "p-e-r"):
        raise RuntimeError("Unsupported decoding style")

    logprobs = pred["logprobs"]
    tokens = logprobs["tokens"]

    completion_start_tok_idx = logprobs["text_offset"].index(pred["completion_offset"])
    completion_end_tok_idx = tokens.index("<|endoftext|>")
    completion_tokens = tokens[completion_start_tok_idx:completion_end_tok_idx]
    completion_probs = logprobs["token_logprobs"][completion_start_tok_idx:completion_end_tok_idx]

    div_pos = completion_tokens.index("\n")
    # answer span 2:div_pos
    # rationale span div_pos + 5:
    ans_log_probs = sum(completion_probs[2:div_pos])
    rat_log_probs = sum(completion_probs[div_pos + 5:])
    return ans_log_probs, rat_log_probs

def evaluate_joint_predictions(dev_set, predictions):
    """Print answer accuracy and rationale accuracy; return per-example answer matches.

    Examples and predictions are paired positionally. Both accuracies are
    divided by ``len(predictions)``. Rationales are compared with
    ``evaluate_rationale_match``.
    """
    qa_matches = []
    rationale_hits = 0

    for gold, guess in zip(dev_set, predictions):
        gold_answer = gold["answer"]
        gold_rationale = gold["text_rationale"]
        guess_rationale = guess["rationale"]
        guess_answer = guess["answer"]

        rationale_hits += evaluate_rationale_match(gold_rationale, guess_rationale)
        qa_matches.append(gold_answer == guess_answer)

    print("QA ACC", sum(qa_matches) / len(predictions))
    print("Rationale Set ACC", rationale_hits / len(predictions))
    return qa_matches

def list_arg_max(k, v):
    """Return ``(k[i], v[i])`` where ``i`` indexes the largest entry of ``v``.

    Ties go to the earliest index (the descending sort is stable). Empty
    ``v`` raises IndexError.
    """
    order = sorted(range(len(v)), key=lambda idx: v[idx], reverse=True)
    top = order[0]
    return k[top], v[top]

def coverage_test(ex, result, style):
    """Check whether re-ranking the top answer branches recovers the gold answer.

    The branches are the keys of the natural prediction's ``top_logprobs``
    entry two tokens into the completion, restricted to those starting with
    a space. Each branch found in ``result["candidate_scores"]`` is scored
    with the sum of its answer and rationale log-probabilities (this relies
    on ``calc_completion_confidence`` accepting ``style``). Its rationale
    counts as valid when it has at least two 'and'-separated clauses and the
    first two both occur verbatim in ``ex["context"]``. Branches without a
    candidate entry get score -10000.0 and an invalid rationale.

    Args:
        ex: example with ``"context"`` and ``"answer"``.
        result: one entry produced by ``in_context_candidates_scoring``.
        style: decoding style, ``"p-e"`` or ``"p-e-r"``.

    Returns:
        ``(searched, filtered)``: whether the best-scoring branch equals the
        gold answer, and the same when branches are ranked by
        ``(rationale valid, score)``. ``(False, False)`` when there is no
        usable branch.

    Raises:
        RuntimeError: if ``style`` is not a supported decoding style.
    """
    if not (style == "p-e" or style == "p-e-r"):
        raise RuntimeError("Unsupported decoding style")

    natural_pred = result["natural"]
    natural_logprobs = natural_pred["logprobs"]
    completion_start_tok_idx = natural_logprobs["text_offset"].index(natural_pred["completion_offset"])

    # +2 cuz 'Answer' ':'
    top_at_answer = natural_logprobs["top_logprobs"][completion_start_tok_idx + 2]
    # clean branches
    ans_branches = [tok for tok in top_at_answer if tok.startswith(' ')]
    if not ans_branches:
        # Nothing to rank, so the gold answer cannot be recovered.
        return False, False

    candidate_scores = result['candidate_scores']

    branch_scores = []
    branch_rat_validness = []
    for branch in ans_branches:
        name = branch.lstrip()
        if name not in candidate_scores:
            # TODO shortcut now, if not in candidate we just assume it being bad
            branch_scores.append(-10000.0)
            branch_rat_validness.append(False)
            continue

        scored = candidate_scores[name]
        process_joint_prediction(scored, style)
        branch_scores.append(sum(calc_completion_confidence(scored, style)))

        clues = scored['rationale'].replace("and ", "and").split('and')
        # A rationale with fewer than two clauses is invalid rather than a crash.
        branch_rat_validness.append(
            len(clues) >= 2
            and clues[0] in ex['context']
            and clues[1] in ex['context']
        )

    best_branch, _ = list_arg_max(ans_branches, branch_scores)
    best_branch_filtered, _ = list_arg_max(ans_branches, list(zip(branch_rat_validness, branch_scores)))

    searched = best_branch.lstrip() == ex['answer']
    filtered = best_branch_filtered.lstrip() == ex['answer']
    return searched, filtered

def analyze_rejection_calib(args):
    """Run ``coverage_test`` over the cached predictions and print the hit rates.

    The first ``args.num_dev`` dev examples are paired positionally with the
    results cached at ``result_cache_name(args)``. One rate is printed per
    flag returned by ``coverage_test`` (currently ``searched`` then
    ``filtered``), space-separated on one line.

    The previous version indexed a third element that ``coverage_test`` never
    returned and so always raised IndexError. Transposing the result tuples
    keeps this function in step with whatever width ``coverage_test``
    returns. An empty result list prints an empty line instead of dividing
    by zero.
    """
    dev_set = read_synth_data("data/250-dev_synth.json")
    dev_set = dev_set[:args.num_dev]
    results = read_json(result_cache_name(args))

    testing_results = [coverage_test(ex, r, args.style) for ex, r in zip(dev_set, results)]

    # zip(*rows) turns the per-example tuples into one column per flag.
    rates = [sum(column) / len(column) for column in zip(*testing_results)]
    print(*rates)

if __name__=='__main__':
    args = _parse_args()
    # --run_prediction queries the model and caches the results; without it
    # the cached results are analysed instead.
    entry_point = test_few_shot_candidates_scoring if args.run_prediction else analyze_rejection_calib
    entry_point(args)
