from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from transformers import GPT2TokenizerFast, GPT2Tokenizer, LlamaTokenizer
from editor import apply_grace_to_model, GraceHyperParams, apply_wise_to_model, wise_config
from editor import ROMEHyperParams, apply_rome_to_model
from editor import FTHyperParams, apply_ft_to_model
from editor import MEMITHyperParams, apply_memit_to_model
from editor import LoRAHyperParams, apply_lora_to_model
from editor import MENDHyperParams, MendRewriteExecutor
from editor import DeferHyperParams, apply_defer_to_model
from editor import FTEWCHyperParams, apply_ft_ewc_to_model
from editor import ICEHyperParams, ICERewriteExecutor
from editor import compute_rewrite_or_rephrase_quality, compute_locality_quality
import torch
import json
from tqdm import tqdm
import logging
import random
import numpy as np
from typing import List
import os
import yaml
import copy
import math
from dataclasses import asdict
from utils import LOG, dictToObj, summary_metrics
from editor.util import nethook

# Expose only the GPU with index 3 to CUDA for this process (hard-coded).
# NOTE(review): this runs after `import torch`; presumably it still takes effect
# because CUDA has not been initialised yet at this point — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'


def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Calls, in order, ``random.seed``, ``np.random.seed``, ``torch.manual_seed``
    and ``torch.cuda.manual_seed_all`` with the same ``seed``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


def grace_edit(model, tok, request, hparams):
    """Apply a GRACE edit and return ``(edited_model, weights_copy)``.

    The arguments are forwarded unchanged to ``apply_grace_to_model``. In
    non-sequential mode ``run`` calls the second element with no arguments
    after evaluating the edit.
    """
    outcome = apply_grace_to_model(model, tok, request, hparams)
    edited, rollback = outcome
    return edited, rollback

def defer_edit(model, tok, request, hparams):
    """Apply a Defer edit and return ``(edited_model, weights_copy)``.

    The arguments are forwarded unchanged to ``apply_defer_to_model``. In
    non-sequential mode ``run`` calls the second element with no arguments
    after evaluating the edit.
    """
    outcome = apply_defer_to_model(model, tok, request, hparams)
    edited, rollback = outcome
    return edited, rollback

def rome_edit(model, tok, request, hparams):
    """Apply a ROME edit and return ``(edited_model, weights_copy)``.

    The arguments are forwarded unchanged to ``apply_rome_to_model``. In
    non-sequential mode ``run`` iterates ``weights_copy.items()`` to write the
    saved values back into the model's parameters.
    """
    outcome = apply_rome_to_model(model, tok, request, hparams)
    edited, saved_weights = outcome
    return edited, saved_weights


def memit_edit(model, tok, request, hparams):
    """Apply MEMIT for one request (or an already-built list of requests).

    Args:
        model: model passed through to ``apply_memit_to_model``.
        tok: tokenizer passed through to ``apply_memit_to_model``.
        request: a single request, or a list of requests. A single request
            is wrapped in a one-element list before the call.
        hparams: hyper-parameters passed through to ``apply_memit_to_model``.

    Returns:
        The ``(edited_model, weights_copy)`` pair produced by
        ``apply_memit_to_model``.
    """
    # isinstance, not ``request is not list``: the identity test compares the
    # value with the ``list`` type object itself, so it is always true and a
    # list argument would be wrapped a second time.
    if not isinstance(request, list):
        request = [request]
    edit_model, weights_copy = apply_memit_to_model(model, tok, request, hparams)
    return edit_model, weights_copy

def memit_mass_edit(model, tok, requests, hparams):
    """Apply MEMIT to a whole batch of requests in one call.

    Args:
        model: model passed through to ``apply_memit_to_model``.
        tok: tokenizer passed through to ``apply_memit_to_model``.
        requests: a list of requests; a lone request is wrapped in a
            one-element list.
        hparams: hyper-parameters passed through to ``apply_memit_to_model``.

    Returns:
        The ``(edited_model, weights_copy)`` pair produced by
        ``apply_memit_to_model``.
    """
    # isinstance also accepts list subclasses, which an exact ``type(...) is``
    # comparison would wrap a second time.
    if not isinstance(requests, list):
        requests = [requests]
    edit_model, weights_copy = apply_memit_to_model(model, tok, requests, hparams)
    return edit_model, weights_copy


def mend_edit(model, tok, request, hparams):
    """Apply a MEND edit through a fresh ``MendRewriteExecutor``.

    Args:
        model: model passed through to ``apply_to_model``.
        tok: tokenizer passed through to ``apply_to_model``.
        request: a single request, or a list of requests. A single request
            is wrapped in a one-element list before the call.
        hparams: hyper-parameters passed through to ``apply_to_model``.

    Returns:
        The ``(edited_model, weights_copy)`` pair produced by
        ``MendRewriteExecutor().apply_to_model``.
    """
    # isinstance also accepts list subclasses, which an exact ``type(...) is``
    # comparison would wrap a second time.
    if not isinstance(request, list):
        request = [request]
    edit_model, weights_copy = MendRewriteExecutor().apply_to_model(model, tok, request, hparams)
    return edit_model, weights_copy


def ft_edit(model, tok, request, hparams):
    """Apply a fine-tuning (FT) edit for one request or a list of requests.

    Args:
        model: model passed through to ``apply_ft_to_model``.
        tok: tokenizer passed through to ``apply_ft_to_model``.
        request: a single request, or a list of requests. A single request
            is wrapped in a one-element list before the call.
        hparams: hyper-parameters passed through to ``apply_ft_to_model``.

    Returns:
        The ``(edited_model, weights_copy)`` pair produced by
        ``apply_ft_to_model``.
    """
    # isinstance, not ``request is not list``: the identity test compares the
    # value with the ``list`` type object itself, so it is always true and a
    # list argument would be wrapped a second time.
    if not isinstance(request, list):
        request = [request]
    edit_model, weights_copy = apply_ft_to_model(model, tok, request, hparams)
    return edit_model, weights_copy

def ft_ewc_edit(model, tok, request, hparams):
    """Apply an FT-EWC edit and return ``(edited_model, weights_copy)``.

    The arguments are forwarded unchanged to ``apply_ft_ewc_to_model``. In
    non-sequential mode ``run`` iterates ``weights_copy.items()`` to write the
    saved values back into the model's parameters.
    """
    outcome = apply_ft_ewc_to_model(model, tok, request, hparams)
    edited, saved_weights = outcome
    return edited, saved_weights

def wise_edit(model, tok, request, hparams):
    """Apply a WISE edit and return ``(edited_model, weights_copy)``.

    The arguments are forwarded unchanged to ``apply_wise_to_model``. In
    non-sequential mode ``run`` calls the second element with no arguments
    after evaluating the edit.
    """
    outcome = apply_wise_to_model(model, tok, request, hparams)
    edited, rollback = outcome
    return edited, rollback


def lora_edit(model, tok, request, hparams):
    """Apply a LoRA edit for one request or a list of requests.

    Args:
        model: model passed through to ``apply_lora_to_model``.
        tok: tokenizer passed through to ``apply_lora_to_model``.
        request: a single request, or a list of requests. A single request
            is wrapped in a one-element list before the call.
        hparams: hyper-parameters passed through to ``apply_lora_to_model``.

    Returns:
        The ``(edited_model, weights_copy)`` pair produced by
        ``apply_lora_to_model``.
    """
    # isinstance, not ``request is not list``: the identity test compares the
    # value with the ``list`` type object itself, so it is always true and a
    # list argument would be wrapped a second time.
    if not isinstance(request, list):
        request = [request]
    edit_model, weights_copy = apply_lora_to_model(model, tok, request, hparams)
    return edit_model, weights_copy


def process_vanilla_inputs(prompts, target_new, **kwargs):
    """Build one request dict per ``(prompt, target_new)`` pair.

    Every request starts with the keys ``prompt``, ``target_new`` and ``id``
    (its position in the input). Optional per-request lists may be supplied
    as keyword arguments and are copied in by position when non-empty:

    * ``rephrase_prompts`` -> ``rephrase_prompt``
    * ``ood_rephrases``    -> ``ood_rephrase``
    * ``act_prompts``      -> ``act_prompt``
    * ``subject``          -> ``subject``

    ``locality_inputs`` maps a locality name to a dict holding parallel
    ``prompt`` and ``ground_truth`` lists; entries whose prompt is ``None``
    are skipped, all others are stored under ``request['locality'][name]``.

    Returns:
        The list of request dicts, in input order.
    """
    requests = []
    for position, (prompt, target) in enumerate(zip(prompts, target_new)):
        requests.append({"prompt": prompt, "target_new": target, "id": position})

    # (keyword argument name, key written into each request)
    optional_fields = (
        ('rephrase_prompts', 'rephrase_prompt'),
        ('ood_rephrases', 'ood_rephrase'),
        ('act_prompts', 'act_prompt'),
        ('subject', 'subject'),
    )
    for kwarg_name, request_key in optional_fields:
        values = kwargs.get(kwarg_name)
        if not values:
            continue
        for position, req in enumerate(requests):
            req[request_key] = values[position]

    locality_inputs = kwargs.get('locality_inputs')
    if locality_inputs:
        for locality_name in locality_inputs.keys():
            for position, req in enumerate(requests):
                if locality_inputs[locality_name]['prompt'][position] is None:
                    continue
                bucket = req.setdefault('locality', dict())
                bucket[locality_name] = {
                    'prompt': locality_inputs[locality_name]['prompt'][position],
                    'ground_truth': locality_inputs[locality_name]['ground_truth'][position],
                }
    return requests


def edit_evaluation(all_metrics, request, edited_model, hparams, tok, idx, ppl_metric, icl_executor):
    """Compute post-edit metrics for one request and merge them into ``all_metrics[idx]``.

    Args:
        all_metrics: list of per-request metric dicts; ``all_metrics[idx]`` must
            already hold the ``'pre'`` metrics, including
            ``['pre']['locality'][f'{key}_output']`` for every locality key of
            the request.
        request: request dict with at least ``prompt`` and ``target_new``;
            ``rephrase_prompt``, ``ood_rephrase`` and ``locality`` are optional.
        edited_model: model to evaluate.
        hparams: hyper-parameter object; ``alg_name``, ``device`` and
            ``verbose`` are read here.
        tok: tokenizer handed to the metric functions.
        idx: index of this request inside ``all_metrics``.
        ppl_metric: when true, the rewrite and rephrase metrics use ``'ppl'``
            instead of ``'token_em'``. The OOD rephrase metric always uses
            ``'ppl'``.
        icl_executor: required when ``hparams.alg_name == 'ICE'``; its
            ``infer_eval`` output is prepended to every evaluated prompt.

    Returns:
        None. ``all_metrics[idx]`` is updated in place with a ``'post'`` entry.
    """
    if hparams.alg_name == 'ICE':
        assert icl_executor is not None
        # Evaluate on a private copy: prepending the retrieved context to the
        # caller's request would stack it again every time the same request is
        # re-evaluated (e.g. during mid-run evaluations).
        request = copy.deepcopy(request)

        def _with_context(text):
            return icl_executor.infer_eval(text) + text

        request['prompt'] = _with_context(request['prompt'])
        for optional_key in ('rephrase_prompt', 'ood_rephrase'):
            if optional_key in request:
                request[optional_key] = _with_context(request[optional_key])
        if 'locality' in request:
            for locality_entry in request['locality'].values():
                locality_entry['prompt'] = _with_context(locality_entry['prompt'])

    eval_metric = 'ppl' if ppl_metric else 'token_em'
    metrics = {'post': {'locality': {}}}
    post = metrics['post']

    post.update(
        compute_rewrite_or_rephrase_quality(edited_model, hparams, tok, request['prompt'], request['target_new'],
                                            hparams.device, eval_metric=eval_metric))
    if 'rephrase_prompt' in request:
        post.update(
            compute_rewrite_or_rephrase_quality(edited_model, hparams, tok, request['rephrase_prompt'],
                                                request['target_new'], hparams.device,
                                                test_rephrase=True, eval_metric=eval_metric))
    if 'ood_rephrase' in request:
        post.update(
            compute_rewrite_or_rephrase_quality(edited_model, hparams, tok, request['ood_rephrase'],
                                                request['target_new'], hparams.device,
                                                test_ood_rephrase=True,
                                                eval_metric='ppl'))

    if 'locality' in request and any(request['locality']):
        pre_locality = all_metrics[idx]['pre']['locality']
        post_locality = post['locality']
        for locality_key, locality_entry in request['locality'].items():
            post_locality.update(
                compute_locality_quality(edited_model, hparams, tok, locality_key,
                                         locality_entry['prompt'],
                                         locality_entry['ground_truth'], hparams.device)
            )
            # The accuracy is computed inside the loop so that every locality
            # key gets its own `_acc` entry, not only the last one iterated.
            output_key = f'{locality_key}_output'
            assert len(post_locality[output_key]) == len(pre_locality[output_key])
            post_locality[f'{locality_key}_acc'] = [
                np.mean(np.equal(ans, label))
                for ans, label in zip(post_locality[output_key], pre_locality[output_key])
            ]
    all_metrics[idx].update(metrics)

    if hparams.verbose:
        LOG.info(
            f"{idx} editing: {request['prompt']} -> {request['target_new']}  \n\n {all_metrics[idx]}"
        )


def _load_json_records(path, limit):
    """Return the first ``limit`` entries of the JSON array stored at ``path``."""
    # The context manager closes the handle as soon as parsing is done.
    with open(path, 'r', encoding='utf-8') as fr:
        return json.load(fr)[:limit]


def _load_edit_requests(ds_path, limit):
    """Load the dataset selected by ``ds_path`` and build the edit requests.

    Returns:
        ``(requests, ppl_metric)``: the output of ``process_vanilla_inputs`` and
        a flag that is true only for the hallucination dataset.

    Raises:
        NotImplementedError: if ``ds_path`` matches none of the known datasets.
    """
    rephrase_prompts = None
    ood_rephrases = None
    ppl_metric = False

    if 'ZsRE' in ds_path:
        loc_data = _load_json_records('./data/ZsRE/zsre_mend_train.json', limit)
        edit_data = _load_json_records('./data/ZsRE/zsre_mend_eval.json', limit)
        act_prompts = [row['loc'] + ' ' + row['loc_ans'] for row in loc_data]
        prompts = [row['src'] for row in edit_data]
        rephrase_prompts = [row['rephrase'] for row in edit_data]
        target_new = [row['alt'] for row in edit_data]
        locality_prompts = [row['loc'] for row in edit_data]
        locality_ans = [row['loc_ans'] for row in edit_data]
    elif ds_path in ['company', 'country', 'temporal', 'ood_sum']:
        loc_data = _load_json_records(f'./data/{ds_path}/{ds_path}-train.json', limit)
        edit_data = _load_json_records(f'./data/{ds_path}/{ds_path}-edit.json', limit)
        act_prompts = [row['locality_prompt'] + ' ' + row['locality_ground_truth'] for row in loc_data]
        prompts = [row['prompt'] for row in edit_data]
        ood_rephrases = [row['ood_rephrase'] for row in edit_data]
        target_new = [row['target_new'] for row in edit_data]
        locality_prompts = [row['locality_prompt'] for row in edit_data]
        locality_ans = [row['locality_ground_truth'] for row in edit_data]
    elif 'counterfact' in ds_path:
        edit_data = _load_json_records('./data/counterfact/counterfact-edit.json', limit)
        act_prompts = [row['subject'] for row in edit_data]
        prompts = [row['prompt'] for row in edit_data]
        rephrase_prompts = [row['rephrase_prompt'] for row in edit_data]
        target_new = [row['target_new'] for row in edit_data]
        locality_prompts = [row['locality_prompt'] for row in edit_data]
        locality_ans = [row['locality_ground_truth'] for row in edit_data]
    elif 'hallucination' in ds_path:
        loc_data = _load_json_records('./data/hallucination/hallucination-train.json', limit)
        edit_data = _load_json_records('./data/hallucination/hallucination-edit.json', limit)
        act_prompts = [row['locality_prompt'] + ' ' + row['locality_ground_truth'] for row in loc_data]
        if len(act_prompts) < len(edit_data):
            # Repeat the activation prompts until there is one per edit, then shuffle.
            act_prompts = (act_prompts * math.ceil(len(edit_data) / len(act_prompts)))[:limit]
            random.shuffle(act_prompts)
        prompts = [row['prompt'] for row in edit_data]
        target_new = [row['target_new'] for row in edit_data]
        locality_prompts = [row['locality_prompt'] for row in edit_data]
        locality_ans = [row['locality_ground_truth'] for row in edit_data]
        ppl_metric = True
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise NotImplementedError(f'Unsupported ds_path: {ds_path!r}')

    subject = [row['subject'] for row in edit_data]
    locality_inputs = {
        'neighborhood': {
            'prompt': locality_prompts,
            'ground_truth': locality_ans
        },
    }
    requests = process_vanilla_inputs(prompts, target_new, rephrase_prompts=rephrase_prompts,
                                      ood_rephrases=ood_rephrases, act_prompts=act_prompts,
                                      subject=subject,
                                      locality_inputs=locality_inputs)
    return requests, ppl_metric


def _build_editor(config):
    """Resolve ``config['alg_name']`` into ``(hparams, edit_func, icl_executor)``.

    ``hparams`` is the algorithm's hyper-parameters overlaid with ``config``.
    ``icl_executor`` is ``None`` for every algorithm except ICE.

    Raises:
        NotImplementedError: if the algorithm name is unknown.
    """
    alg_name = config['alg_name']
    icl_executor = None

    if alg_name == 'WISE':
        # WISE uses the ready-made `wise_config` object instead of an hparams file.
        base_hparams, edit_func = wise_config, wise_edit
    elif alg_name == 'ICE':
        # The edit function is a bound method of the executor created below.
        base_hparams, edit_func = ICEHyperParams.from_hparams(config['hparams_file']), None
    else:
        # Built inside the function so the names are only looked up when called.
        file_based = {
            'GRACE': (GraceHyperParams, grace_edit),
            'Defer': (DeferHyperParams, defer_edit),
            'ROME': (ROMEHyperParams, rome_edit),
            'FT': (FTHyperParams, ft_edit),
            'FT_EWC': (FTEWCHyperParams, ft_ewc_edit),
            'MEMIT': (MEMITHyperParams, memit_edit),
            'MEMIT-MASS': (MEMITHyperParams, memit_mass_edit),
            'MEND': (MENDHyperParams, mend_edit),
            'LoRA': (LoRAHyperParams, lora_edit),
        }
        if alg_name not in file_based:
            raise NotImplementedError(f'Unsupported alg_name: {alg_name!r}')
        hparams_cls, edit_func = file_based[alg_name]
        base_hparams = hparams_cls.from_hparams(config['hparams_file'])

    # Values from `config` take precedence over the algorithm's own hyper-parameters.
    hparams = dictToObj({**asdict(base_hparams), **config})

    if alg_name == 'ICE':
        embedding_tok = AutoTokenizer.from_pretrained(hparams.embedding_model, use_fast=False)
        embedding_model = AutoModel.from_pretrained(hparams.embedding_model).cpu()
        icl_executor = ICERewriteExecutor(embedding_tok, embedding_model)
        edit_func = icl_executor.apply_ice_to_model

    return hparams, edit_func, icl_executor


def _pre_edit_metrics(model, hparams, tok, request, ppl_metric):
    """Return the ``{'pre': ...}`` metric dict of the unedited model for one request."""
    eval_metric = 'ppl' if ppl_metric else 'token_em'
    metrics = {'pre': {'locality': {}}}
    pre = metrics['pre']
    pre.update(
        compute_rewrite_or_rephrase_quality(model, hparams, tok, request['prompt'], request['target_new'],
                                            hparams.device, eval_metric=eval_metric))
    if 'rephrase_prompt' in request:
        pre.update(
            compute_rewrite_or_rephrase_quality(model, hparams, tok, request['rephrase_prompt'],
                                                request['target_new'], hparams.device,
                                                test_rephrase=True,
                                                eval_metric=eval_metric))
    if 'ood_rephrase' in request:
        pre.update(
            compute_rewrite_or_rephrase_quality(model, hparams, tok, request['ood_rephrase'],
                                                request['target_new'], hparams.device,
                                                test_ood_rephrase=True,
                                                eval_metric='ppl'))
    if 'locality' in request and any(request['locality']):
        for locality_key, locality_entry in request['locality'].items():
            pre['locality'].update(
                compute_locality_quality(model, hparams, tok, locality_key,
                                         locality_entry['prompt'],
                                         locality_entry['ground_truth'], hparams.device)
            )
    return metrics


def _save_metrics(all_metrics, hparams, total, upto, step=None):
    """Write the full metric list and the summary of its first ``upto`` entries.

    The files are named ``{mode}_{alg}_{total}[_{step}]_{model}_results.json``
    and ``..._summary.json`` inside ``output_dir/ds_path``.
    """
    mode = "sequential" if hparams.sequential_edit else "single"
    model_tag = hparams.model_name.split("/")[-1]
    step_part = '' if step is None else f'{step}_'
    stem = os.path.join(hparams.output_dir, hparams.ds_path,
                        f'{mode}_{hparams.alg_name}_{total}_{step_part}{model_tag}')

    with open(f'{stem}_results.json', 'w') as fw:
        json.dump(all_metrics, fw, indent=4)

    mean_metrics = summary_metrics(all_metrics[:upto])

    with open(f'{stem}_summary.json', 'w') as fw:
        json.dump(mean_metrics, fw, indent=4)


def run(config, **kwargs):
    """Run one editing experiment described by ``config`` and write its metrics.

    Steps: seed the RNGs, load up to ``K`` edit requests for ``config['ds_path']``,
    resolve the editing algorithm, load the model and tokenizer, compute (or
    reload from cache) the pre-edit metrics, apply the edits sequentially or
    one at a time depending on ``hparams.sequential_edit``, evaluate them with
    ``edit_evaluation`` and dump the results and summary JSON files.

    Args:
        config: mapping with at least ``ds_path``, ``alg_name`` and, for all
            algorithms except WISE, ``hparams_file``. Its entries override the
            algorithm's own hyper-parameters.
        **kwargs: accepted but not used.

    Raises:
        NotImplementedError: for an unknown dataset or algorithm name.
        ValueError: if a cached pre-edit metrics file holds fewer entries than
            there are requests.
    """
    set_seed(0)

    K = 1000
    mid_evaluation_freq = None

    requests, ppl_metric = _load_edit_requests(config['ds_path'], K)
    hparams, edit_func, icl_executor = _build_editor(config)

    model = AutoModelForCausalLM.from_pretrained(hparams.model_name, device_map='auto')
    tok = AutoTokenizer.from_pretrained(hparams.model_name, use_fast=False)
    tok.pad_token_id = tok.eos_token_id

    if hparams.alg_name not in ['ROME', 'MEMIT', 'MEMIT-MASS']:
        LOG.info('AutoRegressive Model detected, set the padding side of Tokenizer to left...')
        tok.padding_side = 'left'
    else:
        tok.padding_side = 'right'

    result_dir = os.path.join(hparams.output_dir, hparams.ds_path)
    os.makedirs(result_dir, exist_ok=True)

    model_tag = hparams.model_name.split("/")[-1]
    # GRACE and Defer share a pre-edit cache that is separate from the other algorithms.
    if hparams.alg_name == 'GRACE' or hparams.alg_name == 'Defer':
        pre_file = os.path.join(result_dir, f'{model_tag}_grace_pre_results.json')
    else:
        pre_file = os.path.join(result_dir, f'{model_tag}_pre_results.json')

    if os.path.exists(pre_file):
        # Reuse the stored pre-edit metrics instead of recomputing them.
        with open(pre_file, 'r') as fr:
            all_metrics = json.load(fr)
        if len(all_metrics) < len(requests):
            # Fail before any editing starts rather than with an IndexError afterwards.
            raise ValueError(
                f'{pre_file} holds {len(all_metrics)} cached pre-edit entries '
                f'but {len(requests)} requests were loaded; delete the file to recompute it.')
    else:
        all_metrics = []
        for _, request in tqdm(enumerate(requests)):
            all_metrics.append(_pre_edit_metrics(model, hparams, tok, request, ppl_metric))

        # Store the pre-edit metrics so later runs can skip this step.
        with open(pre_file, 'w') as fw:
            json.dump(all_metrics, fw, indent=4)

    if hparams.sequential_edit:
        if hparams.alg_name == 'MEMIT-MASS':
            if mid_evaluation_freq is not None:
                # Iterate over the requests actually loaded; the dataset may
                # contain fewer than K entries.
                for start in range(0, len(requests), mid_evaluation_freq):
                    stop = min(start + mid_evaluation_freq, len(requests))
                    edited_model, _ = edit_func(model, tok, requests[start:stop], hparams)
                    for j in range(start, stop):
                        edit_evaluation(all_metrics, requests[j], edited_model, hparams, tok, j,
                                        ppl_metric=ppl_metric, icl_executor=icl_executor)
                    _save_metrics(all_metrics, hparams, K, upto=stop, step=stop)
            else:
                edited_model, _ = edit_func(model, tok, requests, hparams)
        else:
            for i, request in tqdm(enumerate(requests)):
                edited_model, _ = edit_func(model, tok, request, hparams)

                if mid_evaluation_freq is not None and (i + 1) % mid_evaluation_freq == 0:
                    for j in range(i + 1):
                        edit_evaluation(all_metrics, requests[j], edited_model, hparams, tok, j,
                                        ppl_metric=ppl_metric, icl_executor=icl_executor)
                    _save_metrics(all_metrics, hparams, K, upto=i + 1, step=i + 1)

        # Final evaluation of every request against the fully edited model.
        for i, request in tqdm(enumerate(requests)):
            edit_evaluation(all_metrics, request, edited_model, hparams, tok, i,
                            ppl_metric=ppl_metric, icl_executor=icl_executor)
    else:
        for i, request in tqdm(enumerate(requests)):
            edited_model, weights_copy = edit_func(model, tok, request, hparams)
            edit_evaluation(all_metrics, request, edited_model, hparams, tok, i,
                            ppl_metric=ppl_metric, icl_executor=icl_executor)
            # Undo the edit so the next request starts from the original model.
            if hparams.alg_name in ['GRACE', 'WISE', 'Defer']:
                weights_copy()
            elif hparams.alg_name == 'LoRA':
                edited_model.unload()
                del model.peft_config
            elif hparams.alg_name == 'ICE':
                icl_executor.roll_back_memory()
            else:
                with torch.no_grad():
                    for k, v in weights_copy.items():
                        nethook.get_parameter(model, k)[...] = v.to(f'cuda:{hparams.device}')

    _save_metrics(all_metrics, hparams, K, upto=K)


if __name__ == '__main__':
    # Read the experiment configuration and start the run. The context manager
    # closes the YAML file once it has been parsed.
    with open('./hparams/config.yaml', 'r') as config_file:
        config = yaml.safe_load(config_file)
    run(config)
