from imports import *
from my_plotly import plot_attn_pattern_from_cache
from transformer_lens import FactoredMatrix
import transformer_lens
from jaxtyping import Float, Int
from transformer_lens.utils import composition_scores
#from rich.traceback import install
#install()


def get_ov(model, layer, head):
    """Return the OV entry for ``head`` from the attention module of block ``layer``."""
    attn = model.blocks[layer].attn
    return attn.OV[head]

def get_qk(model, mov_lay, mov_head):
    """Build the bias-augmented QK circuit of one head as a FactoredMatrix.

    The head's query and key biases are appended as an extra final row to
    W_Q and W_K respectively, and the result is returned as
    ``FactoredMatrix(W_Q_eff, W_K_eff^T)``.

    Args:
        model: object exposing ``blocks[mov_lay].attn`` with ``W_Q``, ``W_K``,
            ``b_Q`` and ``b_K`` indexable by head.
        mov_lay: index of the block to read from.
        mov_head: index of the head inside that block.
    """
    attn = model.blocks[mov_lay].attn

    def with_bias_row(weight, bias):
        # assumes weight is 2-D and bias is 1-D with matching last dim -- TODO confirm
        return torch.cat([weight, bias[None, :]], dim=0)

    W_Q_eff = with_bias_row(attn.W_Q[mov_head], attn.b_Q[mov_head])
    W_K_eff = with_bias_row(attn.W_K[mov_head], attn.b_K[mov_head])
    return FactoredMatrix(W_Q_eff, W_K_eff.transpose(-1, -2))

def make_even(u, s, v):
    """Rebuild a FactoredMatrix from (u, s, v) with sqrt(s) applied to each factor.

    The left factor is ``u`` with its columns scaled by ``sqrt(s)``; the right
    factor is ``transformer_lens.utils.transpose(v)`` with its rows scaled by
    ``sqrt(s)``.
    """
    root_s = s.sqrt()
    left = u * root_s[..., None, :]
    right = root_s[..., :, None] * transformer_lens.utils.transpose(v)
    return FactoredMatrix(left, right)
def re_get_single_component(u, s, v, i):
    """Return the FactoredMatrix that keeps only singular value ``i``.

    A copy of ``s`` has every entry before and after position ``i`` zeroed;
    ``u`` and ``v`` are passed through untouched to ``make_even``.
    """
    only_i = s.clone()
    only_i[:i] = 0
    # When i is the last index this slice is empty, so the assignment is a no-op.
    only_i[i + 1:] = 0
    return make_even(u, only_i, v)
def remove_components(u, s, v, dims):
    """Zero the singular values selected by ``dims`` and rebuild in float64.

    ``dims`` is used directly as an index into a copy of ``s``. All three
    factors are converted with ``.double()`` before being handed to
    ``make_even``.
    """
    pruned = s.clone()
    pruned[dims] = 0
    u64, s64, v64 = (t.double() for t in (u, pruned, v))
    return make_even(u64, s64, v64)
def keep_components(u, s, v, dims):
    """Keep only the singular values whose index is in ``dims``; zero the rest.

    Membership is tested with ``index in dims`` for every index along the
    first dimension of ``s``. The input ``s`` is not modified.
    """
    kept = s.clone()
    keep_mask = torch.tensor(
        [idx in dims for idx in range(len(s))],
        dtype=torch.bool,
        device=kept.device,
    )
    kept[~keep_mask] = 0
    return make_even(u, kept, v)

def all_composition_scores(left, model, mode) -> Float[torch.Tensor, "n_layers n_heads"]:
    """Score ``left`` against every head circuit of ``model``.

    Args:
        left: matrix passed as the left operand to ``composition_scores``.
        model: object exposing ``QK`` and ``OV``.
        mode: ``"Q"`` scores against ``model.QK``, ``"K"`` against
            ``model.QK.T`` and ``"V"`` against ``model.OV``.

    Returns:
        The result of ``composition_scores(left, right, broadcast_dims=True)``.
        Scores are returned unmasked: pairs where the right head is not in a
        later layer than the left head are NOT zeroed out.

    Raises:
        ValueError: if ``mode`` is not one of ``"Q"``, ``"K"``, ``"V"``.

    NOTE(review): the return annotation names only ``n_layers n_heads``; with
    ``broadcast_dims=True`` the result presumably also carries the leading
    dimensions of ``left`` -- confirm against callers.
    """
    if mode == "Q":
        right = model.QK
    elif mode == "K":
        right = model.QK.T
    elif mode == "V":
        right = model.OV
    else:
        raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")
    # Debug output kept from the original implementation.
    print()
    scores = composition_scores(left, right, broadcast_dims=True)
    print("scores",scores.shape)
    return scores

def refactor_ov(W_O, W_V, b_O, b_V):
    """Fold the value bias into the output bias and re-derive W_O / W_V via SVD.

    Based on: https://github.com/neelnanda-io/TransformerLens/blob/1062b187735f92ab6f0782b9e788d722c94e3962/transformer_lens/HookedTransformer.py#L1447

    Returns:
        ``(W_O_new, W_V_new, b_O_new, b_V_new)`` where ``b_V_new`` is all zeros
        and ``b_O_new`` is ``b_O`` plus ``b_V`` contracted with ``W_O`` over
        the head dimension.

    NOTE(review): ``ov_from_usv`` is not defined in the visible part of this
    file; presumably it comes from the star import -- confirm.
    """
    # Contribution of the value bias once it has passed through W_O.
    folded_value_bias = einsum(
        "d_head, d_head d_model -> d_model", b_V, W_O
    )
    b_O_new = b_O + folded_value_bias
    b_V_new = torch.zeros_like(b_V)

    U, S, Vh = FactoredMatrix(W_V, W_O).svd()
    W_O_new, W_V_new = ov_from_usv(U, S, Vh)
    return W_O_new, W_V_new, b_O_new, b_V_new

def refactor_qk(W_Q, W_K, b_Q, b_K):
    """Rebalance the query/key weights and biases of one head.

    Based on: https://github.com/neelnanda-io/TransformerLens/blob/1062b187735f92ab6f0782b9e788d722c94e3962/transformer_lens/HookedTransformer.py#L1447

    Each bias is appended to its weight matrix as an extra final row. The
    augmented pair is wrapped in a FactoredMatrix, passed through
    ``make_even()``, and split back into weights (all rows but the last) and
    biases (the last row).

    Returns:
        ``(W_Q_new, W_K_new, b_Q_new, b_K_new)``.
    """
    def append_bias_row(weight, bias):
        # assumes bias is 1-D and matches weight's last dimension -- TODO confirm
        return torch.cat([weight, bias[None, :]], dim=0)

    q_aug = append_bias_row(W_Q, b_Q)
    k_aug = append_bias_row(W_K, b_K)

    balanced = FactoredMatrix(q_aug, k_aug.transpose(-1, -2)).make_even()
    q_even, k_even_T = balanced.pair
    k_even = k_even_T.transpose(-1, -2)

    # Undo the augmentation: the final row is the bias, the rest is the weight.
    return q_even[:-1, :], k_even[:-1, :], q_even[-1, :], k_even[-1, :]




# Labels stored in Head.name to record which circuit a Head object holds.
QKNAME='qk'  # query-key circuit
OVNAME='ov'  # output-value circuit
UNNAMED='unnamed'  # default label when no circuit type is supplied

class Head(FactoredMatrix):
    """A FactoredMatrix tagged with the layer, head index and label it came from."""

    def __init__(self, A, B, layer: int, head: int, name: str = UNNAMED):
        """Store copies of ``A`` and ``B`` and record the head's identity."""
        # Factors are cloned, so the Head does not share storage with A and B.
        super().__init__(A.clone(), B.clone())
        self._init_(layer=layer, head=head, name=name)

    def _init_(self, layer: int, head: int, name: str = UNNAMED):
        """Attach identity metadata.

        Kept separate from ``__init__`` so it can also be used when upgrading
        a factored matrix into a Head object.
        """
        # name is either qk or ov or something else
        self.layer, self.head, self.name = layer, head, name


def norm(x, ord='fro'):
    """Return ``torch.linalg.norm`` of ``x``, using the Frobenius norm by default."""
    magnitude = torch.linalg.norm(x, ord=ord)
    return magnitude


def composition_score(a, b):
    """Return ``norm(a @ b) / (norm(a) * norm(b))`` using the module's ``norm``."""
    product_norm = norm(torch.matmul(a, b))
    scale = norm(a) * norm(b)
    return product_norm / scale

def get_single_component(u, s, v, i):
    """Return the rank-one matrix ``s[i] * outer(u[:, i], v[i, :])``.

    ``u`` is read by column and ``v`` by row, so ``v`` must hold the
    right-hand vectors as rows.
    """
    scale = s[i].item()
    left_vec = u[:, i]
    right_vec = v[i, :]
    return scale * torch.outer(left_vec, right_vec)

class HeadComposer:
    """Computes composition scores between attention-head circuits of a model.

    Whole-head scores are ``composition_score(later, earlier)`` minus
    ``rand_comp_score``, a baseline averaged over pairs of random matrices.
    The per-SVD-component scores (``compose_svd_comps`` and
    ``head_svd_reads_from_head``) are raw and do NOT subtract the baseline.

    NOTE(review): products are always formed as ``later @ earlier``; confirm
    this matches the vector convention of the model's QK/OV matrices.
    """

    def __init__(self, model: HookedTransformer, rand_trials=50):
        """Store ``model`` and estimate the random-matrix baseline.

        Args:
            model: the transformer whose heads are composed.
            rand_trials: number of random ``d_model x d_model`` pairs (drawn
                with ``torch.rand``) averaged into ``rand_comp_score``.
        """
        self.model = model
        rand_comp_score = 0.
        for _ in range(rand_trials):
            rand1 = torch.rand(model.cfg.d_model, model.cfg.d_model)
            rand2 = torch.rand(model.cfg.d_model, model.cfg.d_model)
            rand_comp_score += composition_score(rand1, rand2)
        # max(..., 1) avoids dividing by zero; with rand_trials=0 the baseline is 0.
        self.rand_comp_score = rand_comp_score/max(rand_trials, 1)
        print('rand comp score', self.rand_comp_score)

    def confirm_head(self, head: FactoredMatrix, name: str):
        """Debug sanity check: a Head must carry the expected label.

        Plain FactoredMatrix inputs have no label and are not checked.
        Can be removed later.
        """
        if isinstance(head, Head):
            assert head.name == name, f"expected a '{name}' head, got '{head.name}'"

    def convert_factmat_to_head(self, factmat: FactoredMatrix, layer: int, head: int, name: str = UNNAMED):
        """Wrap the factors of ``factmat`` in a Head tagged with layer/head/name."""
        return Head(factmat.A, factmat.B, layer, head, name)

    def ov(self, layer, head):
        """Return the OV circuit of (layer, head) as a Head labelled OVNAME."""
        factmat = self.model.blocks[layer].attn.OV[head]
        return self.convert_factmat_to_head(factmat, layer, head, OVNAME)

    def qk(self, layer, head):
        """Return the QK circuit of (layer, head) as a Head labelled QKNAME."""
        factmat = self.model.blocks[layer].attn.QK[head]
        return self.convert_factmat_to_head(factmat, layer, head, QKNAME)

    def composed_head(self, later: Head, earlier: Head, name: str = UNNAMED):
        """Return a Head whose factors are the full products of ``later`` and ``earlier``.

        The result is tagged with the later head's layer/head indices and a
        generated label of the form ``"<later>@<earlier>"``. The ``name``
        argument is accepted but not used.
        """
        newname = f"{later.layer}.{later.head}{later.name}@{earlier.layer}.{earlier.head}{earlier.name}"
        return Head(later.AB, earlier.AB, later.layer, later.head, name=newname)

    def q_compose(self, later_head: FactoredMatrix, earlier_head: FactoredMatrix):
        """Baseline-subtracted score of ``wqk.T @ wov``."""
        self.confirm_head(later_head, QKNAME)
        self.confirm_head(earlier_head, OVNAME)
        qk = later_head.T.AB
        ov = earlier_head.AB
        return composition_score(qk, ov)-self.rand_comp_score

    def k_compose(self, later_head: FactoredMatrix, earlier_head: FactoredMatrix):
        """Baseline-subtracted score of ``wqk @ wov``."""
        self.confirm_head(later_head, QKNAME)
        self.confirm_head(earlier_head, OVNAME)
        qk = later_head.AB
        ov = earlier_head.AB
        return composition_score(qk, ov)-self.rand_comp_score

    def v_compose(self, later_head: FactoredMatrix, earlier_head: FactoredMatrix):
        """Baseline-subtracted score of ``wov_later @ wov_earlier``."""
        self.confirm_head(later_head, OVNAME)
        self.confirm_head(earlier_head, OVNAME)
        ov2 = later_head.AB
        ov1 = earlier_head.AB
        return composition_score(ov2, ov1)-self.rand_comp_score

    def _compose_head_with_layer(self, src_layer, src_head, dest_layer, get_dest, compose):
        """Score the OV circuit of one source head against every head of ``dest_layer``.

        ``get_dest(layer, head_idx)`` builds the destination circuit and
        ``compose(dest, src)`` returns the already baseline-subtracted score,
        so the baseline is not subtracted again here.
        """
        writer = self.ov(src_layer, src_head)
        scores = [
            compose(get_dest(dest_layer, head_idx), writer)
            for head_idx in range(self.model.cfg.n_heads)
        ]
        return torch.tensor(scores)

    def q_compose_from_head_to_layer(self, src_layer, src_head, dest_layer):
        """Q-composition scores of one source head with each head in ``dest_layer``."""
        #TODO: Switch order of dest and src and fix in all other functions
        return self._compose_head_with_layer(src_layer, src_head, dest_layer, self.qk, self.q_compose)

    def k_compose_from_head_to_layer(self, src_layer, src_head, dest_layer):
        """K-composition scores of one source head with each head in ``dest_layer``."""
        return self._compose_head_with_layer(src_layer, src_head, dest_layer, self.qk, self.k_compose)

    def v_compose_from_head_to_layer(self, src_layer, src_head, dest_layer):
        """V-composition scores of one source head with each head in ``dest_layer``."""
        return self._compose_head_with_layer(src_layer, src_head, dest_layer, self.ov, self.v_compose)

    def compose_all_later_heads(self, src_layer, src_head, comptype):
        """Score one source head against every head in every later layer.

        Args:
            src_layer: layer of the source head.
            src_head: index of the source head.
            comptype: ``'q'``, ``'k'`` or ``'v'`` (case-sensitive).

        Returns:
            An ``(n_layers, n_heads)`` tensor; rows for layers up to and
            including ``src_layer`` stay zero.

        Raises:
            Exception: if ``comptype`` is not a valid composition type.
        """
        per_layer = {
            'q': self.q_compose_from_head_to_layer,
            'k': self.k_compose_from_head_to_layer,
            'v': self.v_compose_from_head_to_layer,
        }
        # Validate once, before the loop, so a bad comptype fails even when
        # there are no later layers to iterate over.
        if comptype not in per_layer:
            raise Exception(f"Pick a valid composition type: q, k, or v, you picked: {comptype}")
        compose_layer = per_layer[comptype]

        cfg = self.model.cfg
        compositions = torch.zeros(cfg.n_layers, cfg.n_heads)
        for dest_layer in range(src_layer+1, cfg.n_layers):
            compositions[dest_layer, :] = compose_layer(src_layer, src_head, dest_layer)
        return compositions

    def _n_svd_components(self):
        """Number of SVD components scored per head: ``d_model // n_heads``."""
        cfg = self.model.cfg
        return cfg.d_model // cfg.n_heads

    @staticmethod
    def _reader_svd(dest_head, comptype):
        """SVD of the reading head, transposed first for 'q'.

        The third factor is returned transposed, so its rows line up with
        ``get_single_component``'s row indexing.
        """
        reader = dest_head.T if comptype.lower() == 'q' else dest_head
        u, s, v = reader.svd()
        return u, s, v.T

    def compose_svd_comps(self, dest_head: FactoredMatrix, src_head: FactoredMatrix, comptype: str):
        """Score every SVD component of ``dest_head`` against every one of ``src_head``.

        Args:
            dest_head: the later (reading) head; transposed first when
                ``comptype`` is ``'q'`` (case-insensitive).
            src_head: the earlier (writing) head.
            comptype: composition type; only ``'q'`` changes the behaviour.

        Returns:
            An ``(ncomps, ncomps)`` tensor of raw ``composition_score`` values,
            indexed ``[dest_component, src_component]``.
        """
        ncomps = self._n_svd_components()
        u2, s2, v2 = self._reader_svd(dest_head, comptype)
        u1, s1, v1 = src_head.svd()
        v1 = v1.T

        compscores = torch.zeros((ncomps, ncomps))
        for i in range(ncomps):
            comp2 = get_single_component(u2, s2, v2, i)
            for j in range(ncomps):
                comp1 = get_single_component(u1, s1, v1, j)
                compscores[i, j] = composition_score(comp2, comp1)
        return compscores

    def head_svd_reads_from_head(self, dest_head: FactoredMatrix, src_head: FactoredMatrix, comptype: str):
        """Score every SVD component of ``dest_head`` against the full ``src_head``.

        Args:
            dest_head: the later (reading) head; transposed first when
                ``comptype`` is ``'q'`` (case-insensitive).
            src_head: the earlier (writing) head, used as its full product.
            comptype: composition type; only ``'q'`` changes the behaviour.

        Returns:
            An ``(ncomps, 1)`` tensor of raw ``composition_score`` values.
        """
        ncomps = self._n_svd_components()
        u2, s2, v2 = self._reader_svd(dest_head, comptype)
        writer = src_head.AB

        compscores = torch.zeros((ncomps, 1))
        for i in range(ncomps):
            comp2 = get_single_component(u2, s2, v2, i)
            compscores[i] = composition_score(comp2, writer)
        return compscores

def compose_all_heads_svd(model, composer):
    """Save per-SVD-component Q/K/V composition scores for every head pair.

    For each head and each head in a later layer, calls
    ``composer.compose_svd_comps`` for 'q', 'k' (against the later QK circuit)
    and 'v' (against the later OV circuit), and writes each result to
    ``<dir>/{layer}_{head}-{later_layer}_{later_head_idx}.npy``.

    NOTE: ``model_name`` is not a parameter; it is read from module scope
    (it is assigned in the ``__main__`` block) to build the output paths.
    """
    print("DEVICE", model.cfg.device, 'Composing SVD')
    q_path = f'exp_site/results/composed_heads/{model_name}/svd_q_comps'
    k_path = f'exp_site/results/composed_heads/{model_name}/svd_k_comps'
    v_path = f'exp_site/results/composed_heads/{model_name}/svd_v_comps'
    os.makedirs(q_path, exist_ok=True)
    os.makedirs(k_path, exist_ok=True)
    os.makedirs(v_path, exist_ok=True)

    for layer in range(model.cfg.n_layers-1):
        for head in range(model.cfg.n_heads):
            # The earlier head is the same for every later head, so build it
            # once here instead of once per inner iteration.
            early_head = composer.ov(layer, head)
            for later_layer in track(list(range(layer+1, model.cfg.n_layers)),
                                    description=f'Composing all later heads with: layer {layer} head {head}'):
                print("Composing with heads in", later_layer)
                for later_head_idx in range(model.cfg.n_heads):
                    later_qk = composer.qk(later_layer, later_head_idx)
                    q_comps = composer.compose_svd_comps(later_qk, early_head, 'q').numpy()
                    k_comps = composer.compose_svd_comps(later_qk, early_head, 'k').numpy()

                    later_ov = composer.ov(later_layer, later_head_idx)
                    v_comps = composer.compose_svd_comps(later_ov, early_head, 'v').numpy()

                    pair_file = f"{layer}_{head}-{later_layer}_{later_head_idx}.npy"
                    np.save(f"{q_path}/{pair_file}", q_comps)
                    np.save(f"{k_path}/{pair_file}", k_comps)
                    np.save(f"{v_path}/{pair_file}", v_comps)

def compose_all_heads_read_svd(model, composer):
    """Save scores of each later head's SVD components against each earlier head.

    For each head and each head in a later layer, calls
    ``composer.head_svd_reads_from_head`` for 'q', 'k' (against the later QK
    circuit) and 'v' (against the later OV circuit), and writes each result to
    ``<dir>/{layer}_{head}-{later_layer}_{later_head_idx}.npy``.

    NOTE: ``model_name`` is not a parameter; it is read from module scope
    (it is assigned in the ``__main__`` block) to build the output paths.
    """
    print("DEVICE", model.cfg.device, 'Composing SVD')
    q_path = f'exp_site/results/composed_heads/{model_name}/read_svd_q_comps'
    k_path = f'exp_site/results/composed_heads/{model_name}/read_svd_k_comps'
    v_path = f'exp_site/results/composed_heads/{model_name}/read_svd_v_comps'
    os.makedirs(q_path, exist_ok=True)
    os.makedirs(k_path, exist_ok=True)
    os.makedirs(v_path, exist_ok=True)

    for layer in range(model.cfg.n_layers-1):
        for head in range(model.cfg.n_heads):
            # The earlier head is the same for every later head, so build it
            # once here instead of once per inner iteration.
            early_head = composer.ov(layer, head)
            for later_layer in track(list(range(layer+1, model.cfg.n_layers)),
                                    description=f'Composing all later heads with: layer {layer} head {head}'):
                print("Composing with heads in", later_layer)
                for later_head_idx in range(model.cfg.n_heads):
                    later_qk = composer.qk(later_layer, later_head_idx)
                    q_comps = composer.head_svd_reads_from_head(later_qk, early_head, 'q').numpy()
                    k_comps = composer.head_svd_reads_from_head(later_qk, early_head, 'k').numpy()

                    later_ov = composer.ov(later_layer, later_head_idx)
                    v_comps = composer.head_svd_reads_from_head(later_ov, early_head, 'v').numpy()

                    pair_file = f"{layer}_{head}-{later_layer}_{later_head_idx}.npy"
                    np.save(f"{q_path}/{pair_file}", q_comps)
                    np.save(f"{k_path}/{pair_file}", k_comps)
                    np.save(f"{v_path}/{pair_file}", v_comps)

def compose_all_heads(model, composer):
    """Save whole-head Q/K/V composition scores for every head but the last layer's.

    For each (layer, head), the three results of
    ``composer.compose_all_later_heads`` are computed first and then written
    to ``.../q_comps``, ``.../k_comps`` and ``.../v_comps`` as
    ``{layer}_{head}.npy``.

    NOTE: ``model_name`` is not a parameter; it is read from module scope
    (it is assigned in the ``__main__`` block) to build the output paths.
    """
    print("DEVICE", model.cfg.device)
    # Ideas not implemented yet: composing with mlp weight matrices,
    # composing with unembeddings?
    out_dirs = {
        comptype: f'exp_site/results/composed_heads/{model_name}/{comptype}_comps'
        for comptype in ('q', 'k', 'v')
    }
    for out_dir in out_dirs.values():
        os.makedirs(out_dir, exist_ok=True)

    for layer in track(list(range(model.cfg.n_layers-1)), description='Iterating over layers'):
        for head in range(model.cfg.n_heads):
            # Current head = (layer, head). Compute all three before saving any.
            scores = {
                comptype: composer.compose_all_later_heads(layer, head, comptype).numpy()
                for comptype in out_dirs
            }
            for comptype, out_dir in out_dirs.items():
                np.save(f"{out_dir}/{layer}_{head}.npy", scores[comptype])

# Script entry point: load a pretrained model and dump composition scores to disk.
if __name__ == "__main__":
    # model_name is also read as a module-level global by the compose_all_heads*
    # functions above when they build their output paths.
    model_name = "attn-only-3l"
    model = HookedTransformer.from_pretrained(model_name)
    # rand_trials=0 skips the random-matrix loop, so the baseline score is 0.
    composer = HeadComposer(model, rand_trials=0)
    compose_all_heads_read_svd(model, composer)
    #compose_all_heads_svd(model, composer)
    