import torch
import math
import json
import sys
import os

# Causal mask over the last `last_q` queries and the last `last_q` keys:
# LAST_Q_MASK[0, 0, i, j] is True exactly when j <= i. Shape (1, 1, last_q, last_q).
# Placed on the GPU when one is available; falling back to CPU keeps the module
# importable on machines without CUDA (e.g. for merge / compare_list).
last_q = 64
arange = torch.arange(last_q, device="cuda" if torch.cuda.is_available() else "cpu")
LAST_Q_MASK = arange[None, None, :, None] >= arange[None, None, None, :]

def load(path, head):
    """Load a saved ``(query, key)`` pair and run the pattern search on one head.

    Args:
        path: file readable by ``torch.load`` holding a ``(query, key)`` pair;
            it is loaded onto the CPU first.
        head: index into dim 1 of both tensors (assumed to be the head axis,
            i.e. tensors shaped (batch, heads, seq, head_dim) -- TODO confirm).

    Returns:
        The result of ``search_pattern`` for that single head.
    """
    query, key = torch.load(path, map_location="cpu")
    # Pick one head, re-insert a singleton axis at dim 1, and move to the GPU.
    single_q = query[:, head].unsqueeze(1).cuda()
    single_k = key[:, head].unsqueeze(1).cuda()
    return search_pattern(single_q, single_k)

def sum_all_diagonal_matrix(mat: torch.Tensor):
    """Sum every diagonal of each ``(n, m)`` matrix in a 4-D stack.

    Args:
        mat: tensor of shape ``(b, h, n, m)``; any ``b`` and ``h`` are supported.

    Returns:
        Tensor of shape ``(b, h, n + m - 1)``. Entry ``i`` is the sum of the
        diagonal with offset ``i - (n - 1)`` (column index minus row index), so
        offsets run from ``-(n - 1)`` (bottom-left corner) to ``m - 1``
        (top-right corner).

    Each row is zero-padded with ``n`` columns on both sides. The padded buffer
    is then re-viewed with a row stride one larger than the padded row width.
    This shifts row ``i`` left by ``i`` elements, so every diagonal of ``mat``
    lines up as a column, and a column sum yields the diagonal sums.
    """
    b, h, n, m = mat.shape
    width = 2 * n + m  # length of a padded row
    zero_mat = torch.zeros((b, h, n, n), device=mat.device)  # padding block
    mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)  # contiguous (b, h, n, width)
    # Real batch/head strides of the padded buffer; the row stride is width + 1
    # (instead of width) so each successive row starts one column further right.
    mat_strided = mat_padded.as_strided(
        (b, h, n, n + m),
        (h * n * width, n * width, width + 1, 1),
    )
    sum_diags = torch.sum(mat_strided, 2)
    # Column 0 is offset -n, which lies entirely outside ``mat``; drop it.
    return sum_diags[:, :, 1:]

def search_pattern(q, k, idx=None):
    """Score candidate sparse-attention patterns against one head's exact attention.

    The exact causal attention ``softmax(q k^T / sqrt(head_dim))`` is computed
    once. Each candidate configuration then builds a 0/1 mask of the same shape
    and is scored by ``coverage``: the attention mass the mask keeps, summed over
    keys and averaged over query rows ``score_start`` onward. Attention rows sum
    to 1, so this is the mean fraction of each query's attention retained.

    Candidates, each called as ``fc(v_size, s_size)``:
      * ``stream_llm`` ``(100, 800)``: keeps the first ``v_size`` key columns
        plus a band of ``s_size`` diagonals below the main diagonal.
      * ``vertical_and_slash`` (four configurations): keeps ``v_size`` key columns
        and ``s_size`` diagonals. Both are ranked by the attention of the last
        ``last_q`` queries.
      * ``retrieval_head`` ``(8, 1)``: mean-pools q/k into ``block_size`` blocks.
        It keeps the ``block_num // 8`` highest-scoring key blocks per query
        block, then applies the causal mask.

    Args:
        q: query tensor. The sequence length is read from dim 2 and the head
            dimension from dim -1; ``load`` passes ``(batch, 1, seq, head_dim)``.
        k: key tensor laid out like ``q`` (assumed same sequence length -- TODO
            confirm).
        idx: unused; kept for call-site compatibility.

    Returns:
        ``[pattern_name, v_size, s_size, score]`` for every configuration, in the
        order tried. The best-scoring configuration is also printed. Each score
        goes through ``.item()``, so ``q``/``k`` must describe a single
        (batch, head) pair.
    """
    q_len = q.shape[2]
    head_dim = q.shape[-1]
    device = q.device
    last_q = 64  # trailing queries used to estimate vertical_and_slash
    block_size = 32  # tile size used by retrieval_head
    # First query row included in the score. NOTE(review): if q_len <= score_start
    # every score is NaN (mean over zero rows) and no configuration wins.
    score_start = 2500

    positions = torch.arange(q_len, device=device)
    causal_mask = positions[None, None, :, None] >= positions[None, None, None, :]
    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    attn_weights = torch.where(causal_mask, attn_weights, -torch.inf)
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
    del causal_mask  # (q_len x q_len) bool; release it before candidates allocate theirs

    def coverage(est_attn):
        """Mean over query rows >= score_start of the attention mass kept by est_attn."""
        kept = attn_weights * est_attn
        return kept[:, :, score_start:].sum(-1).mean(-1).squeeze().float().detach().cpu().numpy()

    def vertical_and_slash(vertical_size, slash_size):
        n_last = min(last_q, q_len)
        # Sizes larger than the sequence would make topk fail; clamp them.
        vertical_size = min(vertical_size, q_len)
        slash_size = min(slash_size, q_len)

        # Estimate attention of the last n_last queries against all keys, causally
        # masked within the trailing n_last x n_last square.
        # NOTE(review): unlike attn_weights these logits are not divided by
        # sqrt(head_dim), so the estimate is a sharper softmax -- confirm intended.
        qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -n_last:, :], k)
        tail_pos = torch.arange(n_last, device=device)
        tail_mask = tail_pos[:, None] >= tail_pos[None, :]
        qk[:, :, :, -n_last:] = torch.where(tail_mask, qk[:, :, :, -n_last:], -torch.inf)
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)

        # Vertical lines: drop the (q_len - vertical_size) key columns with the
        # least estimated mass; the first 30 columns rank last for dropping.
        vertical = qk.sum(-2, keepdim=True)
        vertical[..., :30] = torch.inf
        drop_cols = torch.topk(-vertical, q_len - vertical_size, -1).indices

        # Slash lines: the first q_len diagonal sums cover global offsets
        # -(q_len - 1) .. 0. The 30 diagonals nearest the main diagonal get a
        # huge score so they are always among the slash_size kept.
        slash = sum_all_diagonal_matrix(qk)[..., :q_len]
        slash[..., -30:] = 10000
        offsets = torch.topk(slash, slash_size, -1).indices - (q_len - 1)
        # Built from (batch 0, head 0) only; shape (1, q_len, q_len).
        slash_mask = torch.sparse.spdiags(
            torch.ones(slash_size, q_len), offsets.cpu()[0][0], (q_len, q_len)
        ).to_dense()[None].to(device)

        est_attn = torch.ones_like(attn_weights)
        est_attn = est_attn.scatter(3, drop_cols.expand(*est_attn.shape[:3], drop_cols.shape[3]), 0)
        est_attn = ((est_attn + slash_mask) > 0).float()
        est_attn = torch.tril(est_attn)
        return coverage(est_attn)

    def stream_llm(vertical_size, slash_size):
        # Band of slash_size diagonals below (and including) the main diagonal.
        mask = torch.triu(torch.tril(torch.ones(q_len, q_len, device=device), 0), -slash_size)
        # Sink columns: the first vertical_size keys for every query ...
        mask[:, :vertical_size] = 1
        # ... restricted back to the causal (lower) triangle.
        mask = torch.tril(mask)
        return coverage(mask[None, None])

    def retrieval_head(topk_ratio, slash_size=None):
        # slash_size is unused; it lets every candidate be called as fc(v, s).
        batch, heads = q.shape[0], q.shape[1]
        block_num = (q_len - 1) // block_size + 1
        padded_len = block_num * block_size

        def pool_blocks(x):
            # Zero-pad the sequence to padded_len, then mean-pool each block.
            buf = torch.zeros(batch, heads, padded_len, head_dim).to(x)
            buf[:, :, :q_len] = x
            return buf.reshape(batch, heads, block_num, block_size, -1).mean(-2)

        block_q, block_k = pool_blocks(q), pool_blocks(k)
        block_pos = torch.arange(block_num, device=device)
        block_causal = block_pos[None, None, :, None] >= block_pos[None, None, None, :]

        qk = torch.matmul(block_q, block_k.transpose(2, 3))
        qk = torch.where(block_causal, qk, -torch.inf)
        # Drop all but the block_num // topk_ratio highest-scoring key blocks per
        # query block; masked (-inf) blocks rank first for dropping.
        drop_blocks = torch.topk(-qk, block_num - block_num // topk_ratio, -1).indices
        est_attn = torch.ones_like(qk).scatter(
            3, drop_blocks.expand(*qk.shape[:3], drop_blocks.shape[3]), 0
        )
        # Expand each block entry to a block_size x block_size tile, crop to q_len.
        est_attn = (
            est_attn.unsqueeze(3)
            .unsqueeze(-1)
            .repeat(1, 1, 1, block_size, 1, block_size)
            .reshape(batch, heads, padded_len, padded_len)[..., :q_len, :q_len]
        )
        est_attn = torch.tril(est_attn)
        return coverage(est_attn)

    candidates = [
        ("stream_llm", stream_llm, [(100, 800)]),
        ("vertical_and_slash", vertical_and_slash, [(50, 800), (100, 750), (500, 700), (3000, 200)]),
        ("retrieval_head", retrieval_head, [(8, 1)]),
    ]

    best_s, best_v, best_score, best_ty = 0, 0, 0, ""
    all_info = []
    for ty, fc, vs_list in candidates:
        for v_size, s_size in vs_list:
            score = fc(v_size, s_size).item()
            all_info.append([ty, v_size, s_size, score])
            if score > best_score:
                best_score = score
                best_s, best_v = s_size, v_size
                best_ty = ty
            torch.cuda.empty_cache()

    print(best_ty, best_v, best_s, best_score)
    return all_info


import json
from collections import defaultdict
def merge(paths, num_heads=32):
    """Merge several pattern-search dumps into one best-pattern-per-head config.

    Each file in ``paths`` holds JSON indexed ``[layer][head]``. Each entry is a
    list of ``[pattern_type, v_size, s_size, score]`` rows, the layout the
    ``__main__`` search loop writes.

    Scores for the same ``(pattern_type, v_size, s_size)`` are averaged across
    files, and each head keeps the configuration with the highest positive mean.
    ``retrieval_head`` rows are only considered when ``v_size == 8``. The layer
    count is taken from the first file; a shorter file simply contributes
    nothing to its missing layers.

    Args:
        paths: JSON files to merge. Empty input yields ``[]``.
        num_heads: heads ``0 .. num_heads - 1`` read per layer (default 32).

    Returns:
        A list with one dict per layer, mapping ``str(head)`` to
        ``(pattern_type, v_size, s_size, mean_score)``. The value is
        ``("", 0, 0, 0)`` when no eligible configuration scored above 0.
    """
    pattern_list = []
    for path in paths:
        with open(path) as f:
            pattern_list.append(json.load(f))
    if not pattern_list:
        return []

    best_list = []
    for layer in range(len(pattern_list[0])):
        best = {}
        for head in range(num_heads):
            scores = defaultdict(list)
            for pattern in pattern_list:
                if layer >= len(pattern):
                    continue
                for ty, v_size, s_size, score in pattern[layer][head]:
                    scores[(ty, v_size, s_size)].append(score)

            best_s, best_v, best_score, best_ty = 0, 0, 0, ""
            for (ty, v_size, s_size), values in scores.items():
                # Only the topk_ratio == 8 retrieval_head variant is eligible.
                if ty == "retrieval_head" and v_size != 8:
                    continue
                mean_score = sum(values) / len(values)
                if mean_score > best_score:
                    best_score = mean_score
                    best_s, best_v = s_size, v_size
                    best_ty = ty
            best[str(head)] = (best_ty, best_v, best_s, best_score)
        best_list.append(best)
    return best_list

def compare_list(pattern1, pattern2, layer = 10, num_heads=32):
    """Print the heads of ``layer`` whose chosen pattern differs between two configs.

    ``pattern1`` and ``pattern2`` are merged configs like those returned by
    ``merge``: a list (per layer) of dicts keyed by ``str(head)`` whose values
    are ``(pattern_type, v_size, s_size, score)``. The trailing score is ignored
    in the comparison. Entries are compared as lists, so a fresh ``merge`` result
    (tuples) matches the same config loaded back from JSON (lists).

    Each differing head is printed as ``head entry1 entry2``. ``num_heads``
    (default 32) sets how many heads ``0 .. num_heads - 1`` are compared.
    """
    for head in range(num_heads):
        key = str(head)
        entry1, entry2 = pattern1[layer][key], pattern2[layer][key]
        if list(entry1[:-1]) != list(entry2[:-1]):
            print(head, entry1, entry2)

# Pattern-search dumps that can be merged into a single best-pattern config.
path1 = "../eval/Llama_3_8B_Instruct_262k_kv_70k_v32_all_best_pattern.json"
path2 = "../eval/config/Llama_3_8B_Instruct_262k_kv_out_10k_v32_best_pattern.json"
path3 = "../eval/Llama_3_8B_Instruct_262k_longbook_qa_chn_v32_best_pattern.json"
path4 = "../eval/Yi_9B_200k_kv_retrievl_v32_best_pattern.json"


def merge_and_save(paths, out_path):
    """Merge the pattern-search dumps in ``paths`` and write the result to ``out_path``.

    This is deliberately a function rather than module-level statements.
    Importing this file, or running its ``__main__`` search entry point, then
    does not read these dumps or overwrite any output config.
    Example: ``merge_and_save([path1, path2], "../eval/<name>_best_pattern.json")``.

    Returns:
        The merged config, as produced by ``merge``.
    """
    merged = merge(paths)
    with open(out_path, "w") as json_file:
        json.dump(merged, json_file)
    return merged

if __name__ == "__main__":
    # Usage: python <this file> LAYER_IDX HEAD
    # Runs the pattern search for one (layer, head) and stores the result in a
    # JSON file laid out as [layer][head] -> list of [type, v_size, s_size, score].
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} LAYER_IDX HEAD")
    layer_idx, head = sys.argv[1], sys.argv[2]
    head_idx = int(head)
    path = "Llama_3_8B_Instruct_262k_kv_70k_v32_all_best_pattern.json"
    if os.path.exists(path):
        with open(path) as json_file:
            best_list = json.load(json_file)
    else:
        best_list = []
    res = load(f"../eval/debug/{layer_idx}_kv_llama3_70k.pt", head_idx)
    if head_idx == 0:
        # Head 0 opens a new layer entry; later heads append to the most recent
        # layer, so heads must be run in order starting from 0.
        best_list.append([res])
    else:
        best_list[-1].append(res)
    print(layer_idx, head, res)
    with open(path, 'w') as json_file:
        json.dump(best_list, json_file)