from imports import *
from utils import load_json
from torch.utils.data import DataLoader, Dataset

from tango import step
import time
import tango
from tango.common import FromParams
import termplotlib as tpl
from weights_composer import re_get_single_component, get_ov, remove_components
import sys
from exp_steps import (
    load_dataset,
    load_model,
    DataParams,
    ModelParams,
    get_token_idx
)
from laundry_list_exp import calc_inhib_score
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name
from my_plotly import *
import plotly.graph_objects as go
from fancy_einsum import einsum
from tango.common import det_hash

from dataclasses import dataclass

def attn_result_hook(
        hook_vals: Float[torch.Tensor, "batch pos head_index d_model"],
        hook: HookPoint,
        head_idx: int,
        pos_idxs: list, #list of ints
        new_result_vecs: Float[torch.Tensor, 'batch d_model']
    ) -> Float[torch.Tensor, "batch pos head_index d_model"]:
    """Overwrite one head's output at one position per example.

    For every batch row ``b``, ``hook_vals[b, pos_idxs[b], head_idx]`` is set
    from ``new_result_vecs``. The callers in this file pass a single vector
    (``vec * scale``), which is broadcast over the batch. The tensor is edited
    in place and returned.

    Args:
        hook_vals: activation at the hook point, laid out per the annotation.
        hook: the hook point this function is attached to (not used).
        head_idx: head whose output is overwritten.
        pos_idxs: one sequence position per batch row.
        new_result_vecs: replacement value(s) for the selected slots.
    """
    batch_rows = range(len(hook_vals))
    hook_vals[batch_rows, pos_idxs, head_idx] = new_result_vecs
    return hook_vals

#TODO: Update this
@step(cacheable=True, deterministic=True, version='002')
def add_scaled_vec_inhib_scores(
    model: ModelParams,
    dataset: DataParams,
    inhib_layer: int,
    inhib_head: int,
    comp_idx: int,
    mover_layer: int,
    mover_head: int,
    scales:list,
    distractor_idx
    ) -> np.ndarray:
    """Inhibition scores after injecting a scaled OV singular vector.

    ``vec`` is the ``comp_idx``-th column of ``v`` from
    ``get_ov(model, inhib_layer, inhib_head).svd()``. For each scale in
    ``scales`` and each batch of the dataset, ``vec * scale`` is written into
    the inhibition head's ``hook_result`` at each prompt's final token. Then
    ``calc_inhib_score`` is evaluated per example for the mover head on the
    resulting cache. When ``comp_idx`` is None no hook is added and the clean
    run is scored.

    Args:
        distractor_idx: forwarded to ``calc_inhib_score`` as ``distract_idx``.

    Returns:
        Flat array of scores, ordered by scale, then batch, then example.
    """
    model = model.model
    model.set_use_attn_result(True)
    dataset = dataset.dataset

    ov = get_ov(model, inhib_layer, inhib_head)
    u, s, v = ov.svd()
    vec = v[:, comp_idx]

    def get_prompt(prompts, idx):
        # Pull example ``idx`` out of a collated batch. 'distractors' is a
        # sequence of per-example sequences, so each inner one is indexed.
        newprompt = dict.fromkeys(prompts)
        distractors = prompts['distractors']
        for key in prompts:
            if key == 'distractors':
                newprompt[key] = [d[idx] for d in distractors]
            else:
                newprompt[key] = prompts[key][idx]
        return newprompt

    inhib_scores = []
    for scale in scales:
        values_to_add = vec * scale  # invariant across batches
        for batch in track(dataset):
            text = batch['text']
            model.reset_hooks()
            if comp_idx is not None:
                # Index of each prompt's final token (BOS included).
                pos_idxs = [
                    len(model.to_str_tokens(prompt, prepend_bos=True)) - 1
                    for prompt in text
                ]
                hook_fn = partial(attn_result_hook, pos_idxs=pos_idxs, head_idx=inhib_head, new_result_vecs=values_to_add)
                model.blocks[inhib_layer].attn.hook_result.add_hook(hook_fn)
                try:
                    _, cache = model.run_with_cache(text, prepend_bos=True)
                finally:
                    # Never leave the injection hook attached, even on failure.
                    model.reset_hooks()
            else:
                # Only the un-hooked run is needed; previously it was also
                # computed (and discarded) when a hook was applied.
                _, cache = model.run_with_cache(text, prepend_bos=True)

            for batch_idx in range(len(text)):
                cur_prompt = get_prompt(batch, batch_idx)
                inhib_score = calc_inhib_score(model, cur_prompt, cache.apply_slice_to_batch_dim(batch_idx), mover_layer, mover_head, distract_idx=distractor_idx)
                inhib_scores.append(inhib_score.item())

    return np.array(inhib_scores)




#TODO: Update this
@step(cacheable=True, deterministic=True, version='001')
def single_example_attn_scores(
    model: ModelParams,
    prompt: dict,
    inhib_layer: int,
    inhib_head: int,
    comp_idx: int,
    mover_layer: int,
    mover_head: int,
    scale:int,
    ) -> tuple:
    """Mover-head attention to label and distractors after a vector injection.

    The ``comp_idx``-th column of ``v`` from the SVD of the inhibition head's
    OV matrix is multiplied by ``scale``. The result is written into that
    head's ``hook_result`` at the prompt's final token. From the hooked run,
    the mover head's attention from the final token is read at two places:
    the token ``' ' + prompt['label']``, and the token ``' ' + d`` for each
    ``d`` in ``prompt['distractors']``.

    Args:
        prompt: dict with 'text', 'label' and 'distractors' keys.

    Returns:
        ``(label_attn, distractors_attn)`` as numpy arrays. The label array
        has a single entry.
    """
    model = model.model
    model.set_use_attn_result(True)

    ov = get_ov(model, inhib_layer, inhib_head)
    u, s, v = ov.svd()
    vec = v[:, comp_idx]

    text = prompt['text']
    tokenized_text = model.to_str_tokens(text, prepend_bos=True)
    last_tok = len(tokenized_text) - 1

    model.reset_hooks()
    hook_fn = partial(attn_result_hook, pos_idxs=[last_tok], head_idx=inhib_head, new_result_vecs=vec * scale)
    model.blocks[inhib_layer].attn.hook_result.add_hook(hook_fn)
    try:
        _, cache = model.run_with_cache(text, prepend_bos=True)
    finally:
        # Never leave the injection hook attached, even on failure.
        model.reset_hooks()

    # Pattern for the single example, indexed below as [head, query_pos, key_pos].
    attn_pat = cache['pattern', mover_layer, 'attn'][0]

    def attn_to(word):
        tok_idx = get_token_idx(tokenized_text, ' ' + word)
        return attn_pat[mover_head, last_tok, tok_idx].item()

    label_attn = [attn_to(prompt['label'])]
    distractors_attn = [attn_to(dist) for dist in prompt['distractors']]
    return np.array(label_attn), np.array(distractors_attn)



"""
{
    "text": " Tomorrow, when I go to the store, I will buy a plate, a pear, and a pen. First, I will get the plate, the pen, and then the",
    "n_objs": 3,
    "query_idx": 1,
    "objects": [
      "plate",
      "pear",
      "pen"
    ]
  },
"""

def topk(model, logits, k=10):
    """Decode the ``k`` highest-scoring token ids at the final position.

    ``logits.topk(k)`` ranks along the last dimension, and ``[:, -1]`` then
    selects the last position. So ``logits`` is presumably laid out as
    [batch, pos, vocab] (TODO confirm). The id tensor is decoded by
    ``model.to_str_tokens`` without a BOS token.
    """
    final_pos_ids = logits.topk(k).indices[:, -1]
    return model.to_str_tokens(final_pos_ids, prepend_bos=False)

def object_in_highest_position(model, logits, objs, k=20):
    """Find which of ``objs`` ranks highest among the top-``k`` predictions.

    Predictions come from ``topk(model, logits, k)``. Elsewhere in this file,
    object tokens are looked up as ``' ' + obj`` (the string tokens carry a
    leading space), while ``objs`` holds bare names. A bare-name comparison
    would therefore miss those tokens, so both forms are checked.

    Note: results cached by Tango steps that call this function were computed
    with the older bare-name-only matching.

    Args:
        model: model whose ``to_str_tokens`` decodes the predicted ids.
        logits: model logits; see ``topk`` for the expected layout.
        objs: candidate object names.
        k: number of top predictions to search (default 20, as before).

    Returns:
        ``(best_obj, best_rank)``: the highest-ranked object (as given in
        ``objs``) and its 0-based rank in the top-``k`` list. Returns
        ``(None, 9999)`` when no object appears.
    """
    top_labels = topk(model, logits, k)
    highest_obj = None
    highest_idx = 9999  # "not found" sentinel, kept for compatibility
    for obj in objs:
        for candidate in (' ' + obj, obj):
            if candidate in top_labels:
                rank = top_labels.index(candidate)
                if rank < highest_idx:
                    highest_idx = rank
                    highest_obj = obj
    return highest_obj, highest_idx


@dataclass
class AnswerAttns:
    """Mover-head attention scores for one prompt.

    As filled by ``single_example_attn_scores_v2``: ``label_attn`` holds the
    attention from the final token to the queried object's token, and
    ``distractors_attn`` holds one value per remaining object.
    """

    label_attn: list  # a single-element list where it is built in this file
    distractors_attn: list

    def to_dict(self):
        """Return the instance's attribute dict itself (not a copy)."""
        return vars(self)

@dataclass
class TopPredictions:
    """Next-token prediction summary for one prompt.

    As built in ``single_example_attn_scores_v2``:
    - ``top_preds``: the top-20 predicted string tokens.
    - ``obj_list``: the object list it was given.
    - ``best_idx`` / ``top_pred``: the rank and object returned by
      ``object_in_highest_position``.
    """

    top_preds: list
    obj_list: list
    best_idx: int
    top_pred: str

    def to_dict(self):
        """Return the instance's attribute dict itself (not a copy)."""
        return vars(self)


@step(cacheable=True, deterministic=True, version='003')
def single_example_attn_scores_v2(
    model: ModelParams,
    prompt: dict,
    inhib_layer: int,
    inhib_head: int,
    comp_idx: int,
    mover_layer: int,
    mover_head: int,
    scale:int,
    ) -> tuple:
    """Attention and top predictions for one prompt with a vector injected.

    The ``comp_idx``-th column of ``v`` from the SVD of the inhibition head's
    OV matrix is multiplied by ``scale``. The result is written into that
    head's ``hook_result`` at the prompt's final token. From the hooked run
    this returns two objects:

    * ``AnswerAttns``: the mover head's attention from the final token to the
      queried object ``prompt['objects'][prompt['query_idx']]``, and to each
      remaining object.
    * ``TopPredictions``: the top-20 next-token strings, the full object list,
      and the highest-ranked object (via ``object_in_highest_position``).

    ``prompt`` is not modified. Earlier versions removed the label from
    ``prompt['objects']`` in place, which also dropped it from the prediction
    ranking. The step version was bumped so those cached results are not
    reused.

    Returns:
        ``(AnswerAttns, TopPredictions)``.
    """
    model = model.model
    model.set_use_attn_result(True)

    ov = get_ov(model, inhib_layer, inhib_head)
    u, s, v = ov.svd()
    vec = v[:, comp_idx]

    text = prompt['text']
    tokenized_text = model.to_str_tokens(text, prepend_bos=True)
    last_tok = len(tokenized_text) - 1

    model.reset_hooks()
    hook_fn = partial(attn_result_hook, pos_idxs=[last_tok], head_idx=inhib_head, new_result_vecs=vec * scale)
    model.blocks[inhib_layer].attn.hook_result.add_hook(hook_fn)
    try:
        logits, cache = model.run_with_cache(text, prepend_bos=True)
    finally:
        # Never leave the injection hook attached, even on failure.
        model.reset_hooks()

    objects = prompt['objects']
    label = objects[prompt['query_idx']]
    # Copy before removing the label: the full list is still needed for the
    # prediction ranking below, and the caller's dict must stay intact.
    distractors = list(objects)
    distractors.remove(label)

    # Pattern for the single example, indexed below as [head, query_pos, key_pos].
    attn_pat = cache['pattern', mover_layer, 'attn'][0]

    def attn_to(word):
        tok_idx = get_token_idx(tokenized_text, ' ' + word)
        return attn_pat[mover_head, last_tok, tok_idx].item()

    ansattns = AnswerAttns([attn_to(label)], [attn_to(d) for d in distractors])

    top_pred_obj, pred_idx = object_in_highest_position(model, logits, objects)
    top_tokens = topk(model, logits, 20)
    toppreds = TopPredictions(top_tokens, objects, pred_idx, top_pred_obj)

    return ansattns, toppreds

@dataclass
class ObjAttns:
    """Per-example result produced by ``attn_scores_v2``."""

    obj_attns: list  # mover-head attention from the final token to each object
    query_idx: int  # position of the queried object within objs
    objs: list
    pred_idxs: list #ranks of the objs in the final predictions

    def to_dict(self):
        """Return the instance's attribute dict itself (not a copy)."""
        return vars(self)


def idxs_of_objs(model, ex, logits):
    """Rank of each object's token in ``logits`` (0 = highest-scoring).

    Each entry of ``ex['objects']`` is mapped to a token id with
    ``model.to_single_token(' ' + obj)``. The result lists, in the same order,
    where that id lands when ``logits`` is sorted in descending order. The
    caller in this file passes ``logits[batch_idx, last_tok]``, i.e. one
    position of one example.
    """
    token_ids = [model.to_single_token(' ' + obj) for obj in ex['objects']]
    ranking = logits.argsort(descending=True)
    ranks = []
    for tok in token_ids:
        match_positions = torch.where(ranking == tok)[0]
        ranks.append(match_positions.item())
    return ranks

@step(cacheable=True, deterministic=True, version='007')
def attn_scores_v2(
    model: ModelParams,
    dataset: DataParams,#prompt: dict,
    inhib_layer: int,
    inhib_head: int,
    comp_idx: int,
    mover_layer: int,
    mover_head: int,
    scale:int,
    ) -> list:
    """Per-example mover-head attention to every object after a vector injection.

    The ``comp_idx``-th column of ``v`` from the SVD of the inhibition head's
    OV matrix is multiplied by ``scale``. The result is written into that
    head's ``hook_result`` at each prompt's final token. For every example in
    ``dataset`` this records:
    - the mover head's attention from the final token to ``' ' + obj`` for
      each object;
    - the vocabulary rank of each object token at the final position
      (``idxs_of_objs``).

    Returns:
        One ``ObjAttns`` per example, in dataset order.
    """
    model = model.model
    model.set_use_attn_result(True)

    ov = get_ov(model, inhib_layer, inhib_head)
    u, s, v = ov.svd()
    values_to_add = v[:, comp_idx] * scale  # same vector for every batch

    def get_prompt(prompts, idx):
        # Pull example ``idx`` out of a collated batch. 'objects' is a
        # sequence of per-example sequences, so each inner one is indexed.
        newprompt = dict.fromkeys(prompts)
        objs = prompts['objects']
        for key in prompts:
            if key == 'objects':
                newprompt[key] = [d[idx] for d in objs]
            else:
                newprompt[key] = prompts[key][idx]
        return newprompt

    all_outputs = []
    for batch in track(dataset.dataset):
        text = batch['text']
        # Tokenize each prompt once (BOS included). The last index is where
        # the vector is injected and where attention/logits are read.
        str_tokens_per_prompt = [model.to_str_tokens(prompt, prepend_bos=True) for prompt in text]
        pos_idxs = [len(toks) - 1 for toks in str_tokens_per_prompt]

        model.reset_hooks()
        hook_fn = partial(attn_result_hook, pos_idxs=pos_idxs, head_idx=inhib_head, new_result_vecs=values_to_add)
        model.blocks[inhib_layer].attn.hook_result.add_hook(hook_fn)
        try:
            logits, cache = model.run_with_cache(text, prepend_bos=True)
        finally:
            # Never leave the injection hook attached, even on failure.
            model.reset_hooks()

        for batch_idx in range(len(text)):
            cur_prompt = get_prompt(batch, batch_idx)
            query_idx = cur_prompt['query_idx'].item()
            str_tokens = str_tokens_per_prompt[batch_idx]
            last_tok = pos_idxs[batch_idx]
            attn_pat = cache['pattern', mover_layer, 'attn'][batch_idx]

            # Attention to every object, the queried one included
            # (its entry is obj_attns[query_idx]).
            obj_attns = [
                attn_pat[mover_head, last_tok, get_token_idx(str_tokens, ' ' + obj)].item()
                for obj in cur_prompt['objects']
            ]
            pred_idxs = idxs_of_objs(model, cur_prompt, logits[batch_idx, last_tok])
            all_outputs.append(ObjAttns(obj_attns, query_idx, cur_prompt['objects'], pred_idxs))

    return all_outputs

@step(cacheable=True, deterministic=True, version='004')
def single_example_generation(
    model: ModelParams,
    prompt: dict,
    inhib_layer: int,
    inhib_head: int,
    comp_idx: int,
    mover_layer: int,
    mover_head: int,
    scale:int,
    max_new_tokens: int = 10,
    temperature: float = 0.0
    ) -> list:
    """Generate from ``prompt['text']`` with a scaled OV singular vector injected.

    The ``comp_idx``-th column of ``v`` from the SVD of the inhibition head's
    OV matrix is multiplied by ``scale``. The result is written into that
    head's ``hook_result`` at the prompt's final token during generation.

    ``mover_layer`` and ``mover_head`` are not used here. They presumably
    exist for parity with the sibling steps; they remain part of the step's
    parameters and cache key, so they are kept.

    Returns:
        Whatever ``model.generate`` returns for this input. For a string
        prompt this is presumably the prompt plus its continuation as a string
        (TODO confirm; the ``-> list`` annotation is left as-is).
    """
    model = model.model
    model.set_use_attn_result(True)

    ov = get_ov(model, inhib_layer, inhib_head)
    u, s, v = ov.svd()
    vec = v[:, comp_idx]

    text = prompt['text']
    tokenized_text = model.to_str_tokens(text, prepend_bos=True)
    pos_idxs = [len(tokenized_text) - 1]

    model.reset_hooks()
    hook_fn = partial(attn_result_hook, pos_idxs=pos_idxs, head_idx=inhib_head, new_result_vecs=vec * scale)
    model.blocks[inhib_layer].attn.hook_result.add_hook(hook_fn)
    try:
        # With the KV cache disabled, each generation step presumably re-runs
        # the whole sequence, so the hook keeps overwriting the same
        # prompt-final position.
        out = model.generate(
            text,
            prepend_bos=True,
            use_past_kv_cache=False,
            max_new_tokens=max_new_tokens,
            temperature = temperature,
            verbose=False)
    finally:
        # Never leave the injection hook attached, even on failure.
        model.reset_hooks()

    return out


if __name__ == "__main__":
    # Usage: <script> INHIB_LAYER INHIB_HEAD MOVER_LAYER MOVER_HEAD
    inhib_layer, inhib_head = int(sys.argv[1]), int(sys.argv[2])
    mover_layer, mover_head = int(sys.argv[3]), int(sys.argv[4])

    ws = tango.Workspace.from_url("/oscar/data/epavlick/jmerull1/weights/tango_workspace")
    #from tango.workspaces.memory_workspace import MemoryWorkspace
    #ws = MemoryWorkspace()
    model_name = 'gpt2-small'
    #dataset_path = 'datasets/laundry_6item_reg1.json'
    dataset_path = 'datasets/laundry_list_3objs.json'


    dataset_params = DataParams(dataset_path, batch_size=1, extra_descriptor='3_ex_pilot')
    model_params = ModelParams(model_name)

    # Preset OV singular-vector index for known inhibition heads, keyed "layer.head".
    preset_comp_idxs = {'8.6': 2, '7.3': 1, '7.9': 6, '8.10': 1}
    try:
        comp_idx = preset_comp_idxs[f'{inhib_layer}.{inhib_head}']
    except KeyError:
        print("No preset comp idx, using 0")
        comp_idx = 0
    print("COMP IDX", comp_idx)

    def scale_exp():
        # Mean inhibition score for each injection scale in [-50, 50).
        for scale in np.arange(-50., 50., 1.):
            inhib_scores = add_scaled_vec_inhib_scores(
                model=model_params,
                dataset=dataset_params,
                inhib_layer=inhib_layer,
                inhib_head=inhib_head,
                comp_idx=comp_idx,
                mover_layer=mover_layer,
                mover_head=mover_head,
                scales=[scale],
                distractor_idx=0
            ).result(ws)
            print("Scale:", scale, "Score:", inhib_scores.mean())

    def single_ex_exp(example):
        # Label / distractor attention for one example at scales -100..100.
        for scale in np.arange(-100., 101., 1.):
            lab, dist = single_example_attn_scores(
                model=model_params,
                prompt=example,
                inhib_layer=inhib_layer,
                inhib_head=inhib_head,
                comp_idx=comp_idx,
                mover_layer=mover_layer,
                mover_head=mover_head,
                scale=scale
            ).result(ws)
            print(scale, ':', lab, dist)

    def single_ex_exp_v2(example):
        # Merged AnswerAttns + TopPredictions fields for each scale -50..50,
        # keyed by the integer scale.
        output_data = {}
        for scale in track(np.arange(-50., 51., 1.)):
            key_idx = int(scale)
            ansattns, toppreds = single_example_attn_scores_v2(
                model=model_params,
                prompt=example,
                inhib_layer=inhib_layer,
                inhib_head=inhib_head,
                comp_idx=comp_idx,
                mover_layer=mover_layer,
                mover_head=mover_head,
                scale=scale,
            ).result(ws)
            # Add an entry per scale. Re-creating the dict here would keep
            # only the last scale.
            output_data[key_idx] = {**ansattns.to_dict(), **toppreds.to_dict()}
            rich.print(key_idx, ':', ansattns.label_attn[0], ansattns.distractors_attn)
        return output_data

    print(dataset_params.dataset.dataset[0])
    print("PROMPT hash", det_hash(dataset_params.dataset.dataset[0]))

    def accuracy(outputs):
        # Percent of examples whose queried object is the top-ranked vocabulary token.
        return 100*sum(a.pred_idxs[a.query_idx] == 0 for a in outputs)/len(outputs)

    def label_attn(outputs):
        # Mean attention to the queried object.
        return np.mean([a.obj_attns[a.query_idx] for a in outputs])

    def all_obj_attns(outputs):
        # Mean attention per object slot (assumes every example has the same object count).
        return np.mean([a.obj_attns for a in outputs], axis=0)

    for n_objs in [2,3,4,5,6,7,8,9,10]:
        # NOTE(review): this config dict is built but never written out. The
        # JSON below contains only `all_outputs`; the config is encoded in
        # `output_path` instead.
        output_data = {'config':
            {
                'inhib_layer': inhib_layer,
                'inhib_head': inhib_head,
                'comp_idx': comp_idx,
                'mover_layer': mover_layer,
                'mover_head': mover_head
            },
            'data': []
        }
        print(f"{n_objs} items")
        dataset_path = f'datasets/laundry_list_250_{n_objs}objs.json'
        output_path = f'exp_site/results/laundry_list/250_{n_objs}objs_{inhib_layer}.{inhib_head}.{comp_idx}_{mover_layer}.{mover_head}.json'
        dataset_params = DataParams(dataset_path, batch_size=25, extra_descriptor=f'{n_objs} objs LL')
        all_outputs = []
        print(torch.cuda.is_available(), 'cuda')
        for scale in np.arange(-100., 101., 1.):
            rich.print("Scale", scale)
            all_outputs_scale = attn_scores_v2(
                model=model_params,
                dataset=dataset_params,
                inhib_layer=inhib_layer,
                inhib_head=inhib_head,
                comp_idx=comp_idx,
                mover_layer=mover_layer,
                mover_head=mover_head,
                scale=scale,
            ).result(ws)
            rich.print(scale,':\n', label_attn(all_outputs_scale), 'avg. label attn.\n', accuracy(all_outputs_scale), '% Acc.')
            #rich.print(all_outputs_scale[:10])
            rich.print('mean obj attns', all_obj_attns(all_outputs_scale))
            rich.print('\n~~~~~~~~~~~~~~~~~~~~~~~~~~')
            all_outputs.append({int(scale):[a.to_dict() for a in all_outputs_scale]})


        with open(output_path, 'w') as f:
            json.dump(all_outputs, f)

    #dataset_path = 'datasets/owt_7.9_inhib_examples.json'
    #dataset_params = DataParams(dataset_path, batch_size=1, extra_descriptor='wh movement example v2')
    #single_ex_exp(example=dataset_params.dataset.dataset[0])


    def single_ex_gen(example):
        # Greedy (temperature 0) generations at scales -100..100 in steps of 5.
        for scale in np.arange(-100, 105, 5.):
            output = single_example_generation(
                model=model_params,
                prompt=example,
                inhib_layer=inhib_layer,
                inhib_head=inhib_head,
                comp_idx=comp_idx,
                mover_layer=mover_layer,
                mover_head=mover_head,
                scale=scale,
                max_new_tokens=15,
                temperature=0.0
            ).result(ws)
            print(scale, ':', output)

    #single_ex_gen(example=dataset_params.dataset.dataset[1])


    #dataset_path = "datasets/owt_8.10_rep_list_examples.json"
    #dataset_params = DataParams(dataset_path, batch_size=1, extra_descriptor='repeating lists')
    #print('repeating lists')
    #print("Example:", dataset_params.dataset.dataset[1])
    #single_ex_exp(example=dataset_params.dataset.dataset[1])