import transformers
import torch
import numpy as np
import random
import seaborn as sns
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from pprint import pprint
import pickle
import json
import typing
from pathlib import Path
from matplotlib.font_manager import FontProperties
import matplotlib.pylab as pylab
import sys

from torch.utils.data import Dataset, DataLoader
import argparse

import re

def string_to_filename(input_string, max_length=255, default_name="default"):
    """Turn an arbitrary string into something usable as a file name.

    Each of the characters ``< > : " / \\ | ? *`` is replaced by an
    underscore, the result is cut to at most ``max_length`` characters and
    trailing whitespace is stripped.

    Args:
        input_string: Text to sanitize (e.g. a model identifier).
        max_length: Maximum number of characters kept before stripping.
        default_name: Value returned when nothing is left after sanitizing.

    Returns:
        The sanitized name, or ``default_name`` if the result is empty.
    """
    replaced = re.sub(r'[<>:"/\\|?*]', '_', input_string)
    # Strip *after* truncating so a cut that lands on whitespace is cleaned up.
    trimmed = replaced[:max_length].rstrip()
    return trimmed if trimmed else default_name

def _mean_and_stderr(values, n_samples):
    """Return ``(mean, standard error)`` of ``values`` for ``n_samples`` samples."""
    return np.mean(values), np.std(values) / np.sqrt(n_samples)


def _plot_series_with_band(label, stats):
    """Draw one labelled mean curve with a shaded +/- one standard error band."""
    x = list(range(len(stats)))
    means = [mu for mu, _ in stats]
    lower = [mu - err for mu, err in stats]
    upper = [mu + err for mu, err in stats]
    sns.lineplot(x=x, y=means, label=label)
    plt.fill_between(x, lower, upper, alpha=0.2)


def _finish_and_save(ylabel, title, output_path):
    """Label the current figure, save it to ``output_path`` and clear it."""
    plt.xlabel('Number of prepends', fontsize=20)
    plt.ylabel(ylabel, fontsize=20)
    plt.title(title, fontsize=20)
    plt.grid(True)
    plt.tight_layout()
    # savefig does not create missing directories, so make sure the target exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path)
    plt.clf()


def plot_graphs(dump_data, filename, title, plot_prob_diff = False):
    """Plot per-model efficacy curves against the number of prepends.

    One line per model is drawn, with a shaded band of +/- one standard error,
    and the figure is written to ``results/{filename}.png``. When
    ``plot_prob_diff`` is true, a second figure showing the mean of
    ``prob_true - prob_false`` is written to ``results/{filename}_probs.png``.

    Args:
        dump_data: Mapping from model name to that model's result dict. The
            keys read here are ``'acc'`` (only its length is used, to
            determine the number of x positions), ``'all_logitis_true'`` /
            ``'all_logitis_false'`` and, for the second figure,
            ``'all_probs_true'`` / ``'all_probs_false'``. Each ``all_*`` entry
            is indexed by prepend count.
        filename: Stem of the output image(s) inside ``results/``.
        title: Title placed on both figures.
        plot_prob_diff: Whether to also produce the probability-difference
            figure.

    Raises:
        ValueError: If ``dump_data`` is empty.
    """
    if not dump_data:
        raise ValueError("dump_data is empty; there is nothing to plot")

    params = {'legend.fontsize': 'x-large',
            'figure.figsize': (8, 6),
            'axes.labelsize': 'x-large',
            'xtick.labelsize':'x-large',
            'ytick.labelsize':'x-large'}
    # NOTE: this updates matplotlib's global rcParams, not just this figure.
    pylab.rcParams.update(params)

    # Every model is plotted over the longest 'acc' series found.
    num_prepend_plus_one = max(len(v["acc"]) for v in dump_data.values())

    # Figure 1: fraction of samples whose "true" logit beats the "false" logit.
    for model_name, model_dump_data in dump_data.items():
        stats = []
        for num_prep in range(num_prepend_plus_one):
            logits_true = np.array(model_dump_data['all_logitis_true'][num_prep])
            logits_false = np.array(model_dump_data['all_logitis_false'][num_prep])
            # Per-sample 0/1 indicator; mean and std do not depend on the order
            # of the entries. Assumes each entry is a flat list of per-sample
            # values -- TODO confirm against the script that writes the pickles.
            correct = (logits_true > logits_false).astype(float)
            stats.append(_mean_and_stderr(correct, len(logits_true)))
        _plot_series_with_band(model_name, stats)

    _finish_and_save('Efficacy Score', title, f"results/{filename}.png")

    if not plot_prob_diff:
        return

    # Figure 2: mean difference between the "true" and "false" probabilities.
    for model_name, model_dump_data in dump_data.items():
        stats = []
        for num_prep in range(num_prepend_plus_one):
            probs_true = np.array(model_dump_data['all_probs_true'][num_prep])
            probs_false = np.array(model_dump_data['all_probs_false'][num_prep])
            stats.append(_mean_and_stderr(probs_true - probs_false, len(probs_true)))
        _plot_series_with_band(model_name, stats)

    _finish_and_save('Efficacy Magnitude', title, f"results/{filename}_probs.png")

def main():
    """Load the pickled per-model results and plot them for every relation.

    For each relation id two result sets are handled in turn: the
    "sentence_false" runs (titled "Hijacking based on ...") and the
    "sentence_true" runs (titled "Prepending answer based on ..."). For each
    set, one pickle per model is read from ``results/`` and the collected
    data is handed to ``plot_graphs`` with the probability plot enabled.
    """
    model_names = ["openai-community/gpt2", "google/gemma-2b", "google/gemma-2b-it", "meta-llama/Llama-2-7b-hf"]
    sub_rel_ids = ["P190", "P103", "P641", "P131"]

    for rel_id in sub_rel_ids:
        # NOTE(review): the first stem starts with an underscore and the second
        # does not, so the pickle/PNG names differ in shape between the two
        # sets -- confirm this matches the files written by the producer.
        plot_jobs = (
            (f"_sentence_false_{rel_id}", f"Hijacking based on {rel_id}"),
            (f"sentence_true_{rel_id}", f"Prepending answer based on {rel_id}"),
        )
        for filename, title in plot_jobs:
            results_by_model = {}
            for name in model_names:
                pickle_path = 'results/' + string_to_filename(name) + f'{filename}.pickle'
                # NOTE: unpickling can execute arbitrary code; only load trusted files.
                with open(pickle_path, 'rb') as handle:
                    results_by_model[name] = pickle.load(handle)

            plot_graphs(results_by_model, filename = filename, title = title, plot_prob_diff = True)

if __name__ == "__main__":
    # Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with
    # the same fixed value before running the script.
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    main()