from transformers import AutoConfig,AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer, top_k_top_p_filtering
import torch
import json
import pandas as pd
import numpy as np
import random
from sklearn.metrics import classification_report, accuracy_score
from tqdm import tqdm

import scipy.stats

from data import CausalDataset, MoralDataset, Example
from adapter import GPT3Adapter
from evaluator import AccuracyEvaluator
from prompt import CausalJudgmentPrompt, MoralJudgmentPrompt, CausalAbstractJudgmentPrompt, MoralAbstractJudgmentPrompt, \
                    CausalFactorPrompt, MoralFactorPrompt
from thought_as_text_translator import MoralTranslator, CausalTranslator
from tqdm import tqdm
import pickle

class HuggingFaceInterpreter:
    """Score "yes"/"no" answers to text prompts with a Hugging Face language model.

    Model names containing ``'gpt'`` are loaded with ``AutoModelForCausalLM`` and scored
    at the last non-padding position of each prompt.  Every other model name is loaded
    with ``AutoModelForMaskedLM`` and scored at a mask token appended to each prompt.
    """

    def __init__(self, model_name, device='cuda:0'):
        """Load the tokenizer and model for ``model_name`` and move the model to ``device``.

        Raises:
            ValueError: if the yes/no answer words are not single tokens for this tokenizer.
        """
        self.model_name = model_name
        self.config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
        if not self._is_causal_lm:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
        else:
            # only for GPT
            # https://github.com/huggingface/transformers/issues/3021
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, pad_token='<|endoftext|>')
            # Right padding is required: calculate_logits locates the last real token
            # of each prompt from the attention-mask sum.
            self.tokenizer.padding_side = "right"
            self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

        self.vocab_size = self.tokenizer.vocab_size
        self.device = device

        # Vocabulary ids of the two answer words (lower-case for the masked LMs,
        # capitalised for GPT, as in the original configuration).
        if not self._is_causal_lm:
            self.idx_no = self._answer_token_id("no")
            self.idx_yes = self._answer_token_id("yes")
        else:
            self.idx_no = self._answer_token_id("No")
            self.idx_yes = self._answer_token_id("Yes")

    @property
    def _is_causal_lm(self):
        """True when the model name marks a GPT-style (left-to-right) model."""
        return 'gpt' in self.model_name

    def _answer_token_id(self, word):
        """Return the vocabulary id of ``word`` as a single token.

        The previous implementation decoded the whole vocabulary into one string and
        used ``str.index``, which yields a character offset rather than a token id.

        NOTE(review): the word is encoded with a leading space because the answer
        position follows whitespace in the masked-LM prompt (``text + ' ' + mask``);
        this is assumed to hold for the GPT prompts as well -- confirm against the data.

        Raises:
            ValueError: if the tokenizer does not map the word to exactly one token.
        """
        token_ids = self.tokenizer.encode(' ' + word, add_special_tokens=False)
        if len(token_ids) != 1:
            raise ValueError(
                'Expected %r to be a single token for %s, got ids %r'
                % (word, self.model_name, list(token_ids)))
        return token_ids[0]

    def extract_input(self, sequences, encode_labels=False):
        """Tokenize ``sequences`` (a string or a list of strings) into a padded batch.

        For masked LMs a mask token is appended to every sequence, unless
        ``encode_labels`` is True, in which case the text is encoded unchanged.
        The encoded batch is moved to ``self.device`` before being returned.
        """
        if isinstance(sequences, str):
            sequences = [sequences]

        if not self._is_causal_lm and not encode_labels:
            # MLM: the answer is read off at the appended mask position
            sequences = [s + ' ' + self.tokenizer.mask_token for s in sequences]
            input_ids = self.tokenizer(sequences, return_tensors="pt", truncation=True, padding=True).to(self.device)
        else:
            # LM (or plain label text for the MLM loss)
            input_ids = self.tokenizer.batch_encode_plus(sequences, padding=True, truncation=True,
                                                         return_tensors="pt").to(self.device)

        return input_ids

    def calculate_logits(self, sequences, lbls, no_grad=True):
        """Return ``(score_yes, score_no)`` log-probabilities for each prompt.

        Args:
            sequences: a prompt string or a list of prompt strings.
            lbls: the answer word(s) paired with the prompts, e.g. ``['Yes', 'No', ...]``.
                Only used to build the ``labels`` argument for masked LMs.
            no_grad: run the forward pass under ``torch.no_grad()`` when True.

        Returns:
            A tuple of two 1-D tensors holding the log-softmax scores of the "yes"
            and "no" tokens at the answer position of every prompt.
        """
        # Normalise single strings so that zip() below pairs prompts with labels
        # instead of iterating over characters.
        if isinstance(sequences, str):
            sequences = [sequences]
        if isinstance(lbls, str):
            lbls = [lbls]

        inputs = self.extract_input(sequences)

        if not self._is_causal_lm:
            # check out "labels" in
            # https://huggingface.co/transformers/model_doc/bert.html#transformers.BertLMHeadModel
            lbl_seqs = [s + ' ' + l for s, l in zip(sequences, lbls)]
            labels = self.extract_input(lbl_seqs, encode_labels=True)['input_ids']
            # NOTE: the loss produced from these labels is not read anywhere in this method.
        else:
            # https://huggingface.co/transformers/model_doc/gpt2.html
            labels = None

        if no_grad:
            with torch.no_grad():
                res = self.model(**inputs, labels=labels, output_hidden_states=True)
        else:
            res = self.model(**inputs, labels=labels, output_hidden_states=True)

        token_logits = res.logits
        if not self._is_causal_lm:
            # 2d boolean mask selects the vocabulary rows at every mask position
            mask_positions = inputs["input_ids"] == self.tokenizer.mask_token_id
            answer_logits = token_logits[mask_positions]
        else:
            # select the last non-padding position of every sequence
            last_positions = (inputs['attention_mask'].sum(dim=1) - 1).long().to(token_logits.device)
            batch_index = torch.arange(last_positions.shape[0], device=token_logits.device)
            answer_logits = token_logits[batch_index, last_positions, :]

        log_p_tokens = torch.nn.functional.log_softmax(answer_logits, dim=1)

        # we don't compare across models...all comparisons are within model
        # so it's fine that we don't normalize
        score_yes, score_no = log_p_tokens[:, self.idx_yes], log_p_tokens[:, self.idx_no]

        return score_yes, score_no

    def sample_answer(self, sequences, lbls, temp=1.0):
        """Sample one answer per prompt: ``0`` means "yes", ``1`` means "no".

        The yes/no log-probabilities are divided by ``temp`` and renormalised over the
        two options before sampling.  Used for the multi-run average.
        """
        score_yes, score_no = self.calculate_logits(sequences, lbls)

        # Softmax over the temperature-scaled pair is the same distribution as
        # exp(s / temp) / sum(exp(s / temp)), but does not turn into 0/0 when both
        # exponentials underflow.
        scaled = torch.stack([score_yes, score_no], dim=1) / temp  # [batch_size, 2]
        ps = torch.nn.functional.softmax(scaled, dim=1)

        answer = torch.distributions.Categorical(probs=ps).sample()

        return answer

    def calculate_gradient(self, score, res):
        """Return the gradients of ``score`` w.r.t. ``res.hidden_states``.

        Best effort: on failure the error is printed and ``None`` is returned.
        """
        try:
            return torch.autograd.grad(score, res.hidden_states, retain_graph=True)
        except Exception as exc:  # keep best-effort behaviour, but let Ctrl-C through
            print('Gradient Error: %s' % (exc,))
            return None

def compute_huggingface_sampling_mertrics(all_answers, full_label_list):
    """Map every model name to the accuracy of its sampled answers.

    ``all_answers`` maps a model name to its list of sampled answers; each list is
    scored against ``full_label_list`` with ``accuracy_score``.
    """
    return {
        name: accuracy_score(full_label_list, sampled)
        for name, sampled in all_answers.items()
    }

def mean_confidence_interval(data, confidence=0.95):
    """Return ``(mean, half_width)`` of a Student-t confidence interval for ``data``.

    The half-width is the standard error of the mean multiplied by the two-sided
    t critical value for ``len(data) - 1`` degrees of freedom.
    """
    samples = 1.0 * np.array(data)
    degrees_of_freedom = len(samples) - 1

    center = np.mean(samples)
    std_err = scipy.stats.sem(samples)

    # two-sided interval: put (1 - confidence) / 2 in each tail
    upper_quantile = (1 + confidence) / 2.
    half_width = std_err * scipy.stats.t.ppf(upper_quantile, degrees_of_freedom)
    return center, half_width

from collections import defaultdict
def average_huggingface_models(list_of_model_metrics):
    """Aggregate per-run scores into a ``(mean, confidence half-width)`` per model.

    ``list_of_model_metrics`` holds one ``{model_name: score}`` dict per run.  Scores
    are grouped by model name and summarised with ``mean_confidence_interval``.
    """
    # group the scores of every run under their model name
    scores_by_model = {}
    for run_metrics in list_of_model_metrics:
        for name in run_metrics:
            scores_by_model.setdefault(name, []).append(run_metrics[name])

    return {
        name: mean_confidence_interval(run_scores)
        for name, run_scores in scores_by_model.items()
    }

def sample_huggingface_answers_on_causal_prompt(causal_data_path='../../data/causal_prompt.json',
                                                models=None, batch_size=4, temp=1.5):
    """Sample one yes/no answer per causal example from each Hugging Face model.

    Args:
        causal_data_path: JSON file with an ``'examples'`` list; each example provides
            ``'input'`` (the prompt text) and ``'target_scores'['Yes']`` (1 when the
            correct answer is yes).
        models: model names to evaluate; defaults to the built-in model list.
        batch_size: number of prompts scored per forward pass.
        temp: sampling temperature forwarded to ``sample_answer``.

    Returns:
        ``(all_answers, full_label_list)`` where ``all_answers`` maps each model name to
        its sampled answers and ``full_label_list`` holds the gold labels, both encoded
        as 0 for yes and 1 for no.
    """
    if models is None:
        # 'EleutherAI/gpt-neo-1.3B'
        models = ['google/electra-large-generator', 'bert-base-uncased', 'bert-large-uncased', 'roberta-large',
                  'albert-xxlarge-v2', 'gpt2-xl']

    # these models are too large to eval on GPU
    cpu_only_models = ('roberta-large', 'albert-xxlarge-v2', 'gpt2-xl', "EleutherAI/gpt-neo-1.3B")

    # Close the data file deterministically instead of leaking the handle.
    with open(causal_data_path) as data_file:
        examples = json.load(data_file)['examples']

    # Prompts and labels are identical for every model, so build them once.  This also
    # keeps full_label_list defined when the model list is empty.
    texts = [ex['input'] for ex in examples]
    is_yes = [ex['target_scores']['Yes'] == 1 for ex in examples]
    text_labels = ['yes' if yes else 'no' for yes in is_yes]
    full_label_list = [0 if yes else 1 for yes in is_yes]

    all_answers = {}

    for model_name in tqdm(models):
        # Fall back to CPU when CUDA is unavailable rather than failing on 'cuda:0'.
        if model_name in cpu_only_models or not torch.cuda.is_available():
            model = HuggingFaceInterpreter(model_name, 'cpu')
        else:
            model = HuggingFaceInterpreter(model_name)

        full_answers = []
        for start in tqdm(range(0, len(texts), batch_size)):
            stop = start + batch_size
            answers = model.sample_answer(texts[start:stop], text_labels[start:stop], temp=temp)
            full_answers.extend(answers.cpu().numpy().tolist())

        all_answers[model_name] = full_answers

    return all_answers, full_label_list

# This is the main function
def run_huggingface_models_on_causal(repeat=5, save_result_folder='../../data/model_outputs/'):
    """Run the causal sampling ``repeat`` times, print averaged accuracy, save results.

    Args:
        repeat: number of independent sampling runs (at least 1).
        save_result_folder: directory for the JSON outputs; pass ``''`` to skip saving.

    Raises:
        ValueError: if ``repeat`` is smaller than 1.
    """
    import os

    if repeat < 1:
        raise ValueError('repeat must be at least 1, got %r' % (repeat,))

    list_of_model_metrics = []
    list_of_all_answers = []
    for _ in range(repeat):
        all_answers, full_label_list = sample_huggingface_answers_on_causal_prompt()
        model_metrics = compute_huggingface_sampling_mertrics(all_answers, full_label_list)
        list_of_model_metrics.append(model_metrics)
        list_of_all_answers.append(all_answers)

    print(average_huggingface_models(list_of_model_metrics))

    # Only the answers sampled in the first run are written to disk.
    all_answers = list_of_all_answers[0]

    if save_result_folder:
        # Create the folder up front so a missing directory cannot discard the results.
        os.makedirs(save_result_folder, exist_ok=True)

        answers_path = os.path.join(save_result_folder, 'huggingface_model_answers_causal_original.json')
        with open(answers_path, 'w') as out_file:
            json.dump(all_answers, out_file)

        # The labels used to be written to '..._labels_moral_original.json', which
        # clobbered the file produced by the moral run; use the causal name instead.
        labels_path = os.path.join(save_result_folder, 'huggingface_model_labels_causal_original.json')
        with open(labels_path, 'w') as out_file:
            json.dump(full_label_list, out_file)

def sample_huggingface_answers_on_moral_dilemma(moral_data_path='../../data/moral_dilemma.json',
                                                models=None, batch_size=4, temp=1.5):
    """Sample one yes/no answer per moral-dilemma example from each Hugging Face model.

    Args:
        moral_data_path: JSON file with an ``'examples'`` list; each example provides
            ``'input'`` (the prompt text) and ``'target_scores'['Yes']`` (1 when the
            correct answer is yes).
        models: model names to evaluate; defaults to the built-in model list.
        batch_size: number of prompts scored per forward pass.
        temp: sampling temperature forwarded to ``sample_answer``.

    Returns:
        ``(all_answers, full_label_list)`` where ``all_answers`` maps each model name to
        its sampled answers and ``full_label_list`` holds the gold labels, both encoded
        as 0 for yes and 1 for no.
    """
    if models is None:
        # 'EleutherAI/gpt-neo-1.3B'
        models = ['google/electra-large-generator', 'bert-base-uncased', 'bert-large-uncased', 'roberta-large',
                  'albert-xxlarge-v2', 'gpt2-xl']

    # these models are too large to eval on GPU
    cpu_only_models = ('roberta-large', 'albert-xxlarge-v2', 'gpt2-xl', "EleutherAI/gpt-neo-1.3B")

    # Close the data file deterministically instead of leaking the handle.
    with open(moral_data_path) as data_file:
        examples = json.load(data_file)['examples']

    # Prompts and labels are identical for every model, so build them once.  This also
    # keeps full_label_list defined when the model list is empty.
    texts = [ex['input'] for ex in examples]
    is_yes = [ex['target_scores']['Yes'] == 1 for ex in examples]
    text_labels = ['yes' if yes else 'no' for yes in is_yes]
    full_label_list = [0 if yes else 1 for yes in is_yes]

    all_answers = {}

    for model_name in tqdm(models):
        # Fall back to CPU when CUDA is unavailable rather than failing on 'cuda:0'.
        if model_name in cpu_only_models or not torch.cuda.is_available():
            model = HuggingFaceInterpreter(model_name, 'cpu')
        else:
            model = HuggingFaceInterpreter(model_name)

        full_answers = []
        for start in tqdm(range(0, len(texts), batch_size)):
            stop = start + batch_size
            answers = model.sample_answer(texts[start:stop], text_labels[start:stop], temp=temp)
            full_answers.extend(answers.cpu().numpy().tolist())

        all_answers[model_name] = full_answers

    return all_answers, full_label_list

# This is the main function
def run_huggingface_models_on_moral(repeat=5, save_result_folder='../../data/model_outputs/'):
    """Run the moral-dilemma sampling ``repeat`` times, print averaged accuracy, save results.

    Args:
        repeat: number of independent sampling runs (at least 1).
        save_result_folder: directory for the JSON outputs; pass ``''`` to skip saving.

    Raises:
        ValueError: if ``repeat`` is smaller than 1.
    """
    import os

    if repeat < 1:
        raise ValueError('repeat must be at least 1, got %r' % (repeat,))

    list_of_model_metrics = []
    list_of_all_answers = []
    for _ in range(repeat):
        all_answers, full_label_list = sample_huggingface_answers_on_moral_dilemma()
        model_metrics = compute_huggingface_sampling_mertrics(all_answers, full_label_list)
        list_of_model_metrics.append(model_metrics)
        list_of_all_answers.append(all_answers)

    print(average_huggingface_models(list_of_model_metrics))

    # Only the answers sampled in the first run are written to disk.
    all_answers = list_of_all_answers[0]

    if save_result_folder:
        # Create the folder up front so a missing directory cannot discard the results.
        os.makedirs(save_result_folder, exist_ok=True)

        # Context managers make sure both files are flushed and closed.
        answers_path = os.path.join(save_result_folder, 'huggingface_model_answers_moral_original.json')
        with open(answers_path, 'w') as out_file:
            json.dump(all_answers, out_file)

        labels_path = os.path.join(save_result_folder, 'huggingface_model_labels_moral_original.json')
        with open(labels_path, 'w') as out_file:
            json.dump(full_label_list, out_file)