from __future__ import annotations
import json
import os
from pathlib import Path
import time
from typing import List, Tuple, Any
from tqdm import tqdm

import torch
from torch import Tensor
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
    LlamaForCausalLM
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.cache_utils import SinkCache

from eval_utils import (
    check_benchmark_availability,
    dump_jsonl,
    create_prompt,
    load_data,
    get_answer,
    LONGBENCH_DATA_NAME_TO_MAX_NEW_TOKENS,
    create_longbench_prompt,
)
from metrics import compute_scores
from args import parse_args
from sparq import *
from datasets import load_dataset
from vllm import LLM, SamplingParams

from patch import patch_hf, minference_patch

# sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def truncate_input(input: list, max_length: int, manner="middle"):
    """Shorten ``input`` to at most ``max_length`` items.

    Args:
        input: Sequence to truncate (a list of token ids in this file).
        max_length: Maximum number of items to keep. A negative value
            disables truncation.
        manner: Truncation strategy. Only ``"middle"`` is implemented: it
            keeps the first and last ``max_length // 2`` items and drops
            everything in between.

    Returns:
        ``input`` itself when no truncation is needed, a new shortened
        sequence for ``manner == "middle"``, or ``None`` for any other
        ``manner``.
    """
    if max_length < 0:
        return input
    if len(input) <= max_length:
        return input
    if manner == "middle":
        split = max_length // 2
        # Slice the tail from an explicit start index. ``input[-split:]``
        # with ``split == 0`` is ``input[0:]`` (the whole sequence), which
        # would return everything for ``max_length`` of 0 or 1.
        return input[:split] + input[len(input) - split:]
    return None

def truncate_by_tokens(input, tok, max_tokens, manner: str = "middle"):
    """Encode ``input`` with ``tok`` and truncate the resulting token list.

    The token counts before and after truncation are printed. Truncation is
    delegated to ``truncate_input`` with the given ``manner``; a negative
    ``max_tokens`` means "no limit".

    Returns:
        The (possibly shortened) list produced by ``tok.encode``.
    """
    encoded = tok.encode(input)
    original_count = len(encoded)
    print(f"# tokens before: {original_count}")

    kept = truncate_input(encoded, max_length=max_tokens, manner=manner)
    kept_count = len(kept)  # type: ignore
    print(f"# tokens after: {kept_count}")

    # Sanity checks: truncation never grows the sequence and respects the
    # budget unless the budget is disabled.
    assert original_count >= kept_count
    assert max_tokens < 0 or kept_count <= max_tokens
    return kept

def get_pred(
    model,
    tok: AutoTokenizer,
    input_text: str,
    max_input_length: int,
    verbose: bool = False,
    generation_config: GenerationConfig = None,
    attn_type: str = 'vllm'
) -> str:
    """Truncate the prompt to ``max_input_length`` tokens and generate.

    Args:
        model: The loaded model. Its ``generate`` method is called with
            vLLM-style arguments when ``attn_type == "vllm"`` and with
            Hugging Face-style arguments otherwise.
        tok: Tokenizer used to encode the prompt and decode the output.
        input_text: Full prompt text.
        max_input_length: Token budget for the prompt; longer prompts are
            truncated in the middle by ``truncate_by_tokens``.
        verbose: When true, print the first and last 200 prompt tokens.
        generation_config: Passed as ``sampling_params`` on the vLLM path
            and as ``generation_config`` on the other path.
        attn_type: Backend selector; only ``"vllm"`` is special-cased here.

    Returns:
        The generated continuation, stripped of surrounding whitespace.
    """
    input_tokens = truncate_by_tokens(input_text, tok, max_input_length)
    if verbose:
        print("# tokens:", len(input_tokens))
        print("=============== Input ===============")
        print(tok.decode(input_tokens[:200]))
        print("...")
        print(tok.decode(input_tokens[-200:]))
        print("=====================================")

    if attn_type == "vllm":
        # ``input_tokens`` is always a flat list of ids, so it must always be
        # wrapped into a batch of one prompt. The previous
        # ``len(input_tokens) != 1`` check left a single-token prompt
        # unwrapped.
        outputs = model.generate(
            prompt_token_ids=[input_tokens],
            sampling_params=generation_config,
        )
        output = outputs[0].outputs[0].text
    else:
        input_ids = torch.tensor(input_tokens).unsqueeze(0).to(model.device)
        outputs = model.generate(input_ids=input_ids, generation_config=generation_config)
        # Drop the echoed prompt tokens and keep only the continuation.
        output = tok.decode(outputs[0, len(input_tokens):], skip_special_tokens=True)

    output = output.strip()
    print("Chunked generation:", output)
    return output

def load_model(
    model_name: str, topk: int=-1, topk_from_layer: int=-1, topk_dims_file_path: str="", use_sparq: bool = False,
    attn_type: str = 'vllm', max_seq_length: int | None = None
):
    """Load the tokenizer and the model for the requested attention backend.

    Args:
        model_name: Model identifier or path passed to ``from_pretrained`` /
            ``LLM``.
        topk: Stored on the config as ``topk`` when not ``-1`` (always stored
            on the ``"minference"`` path).
        topk_from_layer: Stored on the config together with ``topk``.
        topk_dims_file_path: Stored on the config when non-empty (always
            stored on the ``"minference"`` path).
        use_sparq: When true, a model is loaded and passed through
            ``apply_sparq`` before the ``attn_type`` dispatch runs.
        attn_type: One of ``"vllm"``, ``"hf"``, ``"minference"``,
            ``"streaming"``, ``"flash_attn"``, ``"inf_llm"``, ``"snap_kv"``.
            Any other value is only accepted together with ``use_sparq``.
        max_seq_length: Forwarded to vLLM as ``max_model_len``; unused by the
            other backends.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``attn_type`` is not recognised and ``use_sparq`` is
            false, i.e. no branch would load a model.
    """
    known_attn_types = (
        "vllm", "hf", "minference", "streaming", "flash_attn", "inf_llm", "snap_kv",
    )
    # Fail fast, before anything heavy is loaded. Previously an unknown
    # attn_type fell through every branch and crashed at the final return
    # with UnboundLocalError.
    if attn_type not in known_attn_types and not use_sparq:
        raise ValueError(
            f"Unknown attn_type {attn_type!r}; expected one of {', '.join(known_attn_types)}"
        )

    tok = AutoTokenizer.from_pretrained(model_name)
    tok.pad_token = tok.eos_token

    if attn_type == "vllm":
        llm = LLM(
            model_name,
            max_num_seqs=1,
            swap_space=64,
            gpu_memory_utilization=0.98,
            max_model_len=max_seq_length,
        )
    else:
        config = AutoConfig.from_pretrained(model_name)
        if "LWM" in model_name:
            config.update({
                'theta': 10000000,
                'max_sequence_length': 131072,
                'scan_attention': True,
                'scan_query_chunk_size': 1024,
                'scan_key_chunk_size': 1024,
                'scan_mlp': True,
                'scan_mlp_chunk_size': 1024,
                'scan_layers': True,
            })
        if topk != -1:
            config.topk = topk
            config.topk_from_layer = topk_from_layer
        if topk_dims_file_path:
            config.topk_dims_file_path = topk_dims_file_path

        def _load_causal_lm(**extra_kwargs):
            # Shared loader for every backend that uses AutoModelForCausalLM
            # on a single CUDA device with the current ``config``.
            return AutoModelForCausalLM.from_pretrained(
                model_name,
                config=config,
                torch_dtype="auto",
                device_map="cuda",
                **extra_kwargs,
            )

        if use_sparq:
            config.topk = 256
            config.local_window = 100
            config.num_top_dim_in_q = 16
            llm = apply_sparq(_load_causal_lm())
            # NOTE(review): if ``attn_type`` also matches a branch below, the
            # SparQ model is replaced by that branch's model. Kept as-is
            # because it is unclear whether ``apply_sparq`` has effects beyond
            # the returned object; confirm the intended precedence.

        if attn_type == "hf":
            config._attn_implementation = "eager"
            llm = _load_causal_lm()
        elif attn_type == "minference":
            # This backend starts from a fresh config, so the LWM / SparQ
            # settings applied above are not carried over.
            config = AutoConfig.from_pretrained(model_name)

            config._attn_implementation = 'eager'
            config.topk = topk
            config.topk_from_layer = topk_from_layer
            config.topk_dims_file_path = topk_dims_file_path

            config.n_init = 32
            config.n_local = 480
            config.topk_ratio = 0.1
            config.block_size = 32

            llm = LlamaForCausalLM.from_pretrained(
                model_name,
                torch_dtype='auto',
                device_map='auto',
                config=config,
            )
            llm = minference_patch(llm)
        elif attn_type == "streaming":
            llm = patch_hf(
                _load_causal_lm(),
                attn_type="streaming",
                attn_kwargs={'n_local': 3968, 'n_init': 128},
            )
        elif attn_type == "flash_attn":
            llm = _load_causal_lm(attn_implementation="flash_attention_2")
        elif attn_type == "inf_llm":
            llm = patch_hf(
                _load_causal_lm(),
                attn_type="inf_llm",
                attn_kwargs={
                    'block_size': 128,
                    'n_init': 128,
                    'n_local': 4096,
                    'topk': 16,
                    'repr_topk': 4,
                    'max_cached_block': 32,
                    'exc_block_size': 512,
                    'base': 500000,
                    'distance_scale': 1.0,
                    'perhead': True
                }
            )
        elif attn_type == "snap_kv":
            # replace_llama() is called before the model is instantiated.
            from snap_kv import replace_llama
            replace_llama()
            llm = _load_causal_lm(attn_implementation="flash_attention_2")

    print("Model and tokenizer loaded.")
    return llm, tok

if __name__ == "__main__":
    # Entry point: load the model once, then evaluate each requested
    # LongBench task, writing predictions to
    # <output_dir>/<model>/prediction_<task>.jsonl and printing the scores.
    args = parse_args()
    model_name = args.model_name_or_path
    max_seq_length = args.max_seq_length
    real_model_name = model_name.split("/")[-1]
    # ``args.task`` may be a single task or a comma-separated list;
    # ``str.split`` yields a one-element list when there is no comma.
    data_names = args.task.split(',')

    # Model
    model, tok = load_model(
        model_name, args.topk, args.topk_from_layer,
        args.topk_dims_file_path, args.use_sparq,
        attn_type=args.attn_type, max_seq_length=max_seq_length
    )
    results = {}

    for data_name in data_names:

        max_new_tokens = LONGBENCH_DATA_NAME_TO_MAX_NEW_TOKENS[data_name]
        # Fall back to a fixed generation budget when the per-task budget
        # would not leave any room for the prompt.
        if max_new_tokens >= max_seq_length:
            max_new_tokens = 500

        if args.attn_type == "vllm":
            generation_config = SamplingParams(
                temperature=0,
                max_tokens=max_new_tokens,
            )
        else:
            generation_config = GenerationConfig(
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                do_sample=False,
                pad_token_id=tok.pad_token_id,
            )

        # Data
        result_dir = Path(args.output_dir, real_model_name)
        result_dir.mkdir(exist_ok=True, parents=True)
        output_path = result_dir / f"prediction_{data_name}.jsonl"
        examples = load_dataset('THUDM/LongBench', data_name, split='test')

        preds = []
        print("==== Evaluation ====")
        print(f"# examples: {len(examples)}")
        print(f"Num eval examples: {args.num_eval_examples}")
        print(f"Verbose: {args.verbose}")
        print(f"Max new tokens: {max_new_tokens}")

        if os.path.exists(output_path) and not args.rewrite:
            print(f"Output file {output_path} exists. Loading from file.")
            compute_scores(output_path, data_name, real_model_name)
            # Reload the saved predictions so finished examples are skipped
            # below. Previously the file was announced as "loaded" but every
            # example was regenerated and the file overwritten.
            # Assumes dump_jsonl writes one JSON object per line -- TODO confirm.
            with open(output_path, encoding="utf-8") as f:
                preds = [json.loads(line) for line in f if line.strip()]

        num_done = len(preds)
        for i, eg in tqdm(enumerate(examples), total=len(examples)):
            if args.num_eval_examples != -1 and i >= args.num_eval_examples:
                break
            if i < num_done:
                # Already predicted in a previous run.
                continue
            input_text = create_longbench_prompt(eg, data_name)
            print(f"====== Example {i} ======")
            pred = get_pred(
                model, tok, input_text,
                max_input_length=max_seq_length-max_new_tokens,
                verbose=args.verbose, generation_config=generation_config,
                attn_type=args.attn_type
            )
            print("Ground Truth", eg["answers"])
            if args.verbose:
                print(pred)
            preds.append(
                {
                    "id": i,
                    "prediction": pred,
                    "ground_truth": eg["answers"],
                    "all_classes": eg["all_classes"],
                }
            )
            # Rewrite the file after every example so an interrupted run
            # keeps its progress.
            dump_jsonl(preds, output_path)
            torch.cuda.empty_cache()

        score = compute_scores(output_path, data_name, real_model_name)
        results[data_name] = score

    print("==== Results ====")
    print(json.dumps(results, indent=2))