import os, sys
import time

# Directory one level above the folder containing this script. It is appended
# to the import search path so that the `utils` package imported below can be
# resolved when the script is run directly.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

import argparse
import pickle
from collections import defaultdict
import torch.types

from utils.utils import set_seed
from utils.utils import get_model
import os
import subprocess

# Keyword arguments forwarded verbatim to every `model.generate(...)` call in
# this script.
# NOTE(review): with Hugging Face-style `generate`, `temperature` and `top_p`
# normally only matter when sampling is enabled - confirm whether they are
# meant to have an effect here.
model_custom_config = {
    # Produce a single new token per call.
    "max_new_tokens": 1,
    "temperature": 0.1,
    "top_p": 0.9,
}

def get_gpu_memory_usage():
    """Return the raw ``nvidia-smi`` report of used GPU memory.

    Runs ``nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits``
    and returns its captured standard output as text, exactly as printed:
    the string is neither parsed nor stripped. The exit status of the
    command is not checked and standard error is not captured.

    NOTE(review): the output is presumably one line per visible GPU with the
    used memory in MB/MiB - confirm against the installed driver's
    ``nvidia-smi`` before parsing it downstream.
    """
    query = [
        "nvidia-smi",
        "--query-gpu=memory.used",
        "--format=csv,noheader,nounits",
    ]
    completed = subprocess.run(query, stdout=subprocess.PIPE, text=True)
    return completed.stdout

def _sync_cuda(device):
    """Block until CUDA work queued on `device` is done; no-op off-GPU.

    CUDA kernels are launched asynchronously, so the host clock must only be
    read after the device has drained its queue. Only the given device is
    synchronised.
    """
    device = torch.device(device)
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def main(args):
    """Benchmark single-step generation latency and GPU memory by prompt length.

    Loads the tokenizer/model through `get_model`, runs one warm-up
    generation, then for every `token_count` in
    `range(args.start, args.end, args.step)` builds a prompt of
    `token_count - 1` repetitions of the word "test", times one
    `model.generate` call (configured by `model_custom_config`) and samples
    `get_gpu_memory_usage()` before and after it.

    Results are stored as
    `{"token_speed_gpu": {token_length: {"time", "start_gpu_util",
    "end_gpu_util"}}}`, keyed by the actual tokenized prompt length, and the
    whole mapping is re-pickled to `<cwd>/<args.log_dir>/<args.save_file>`
    after every measurement so partial results survive an interrupted run.

    Args:
        args: Parsed command-line namespace. Reads `cuda` ("auto" or a GPU
            index as a string), `model_path`, `method`, `start`, `end`,
            `step`, `log_dir` and `save_file`; the namespace is also passed
            on to `get_model`.
    """
    # Resolve and create the output directory up front: a missing directory
    # would otherwise only surface at the first save, after the model load
    # and the first measurement have already been paid for.
    log_dir = os.path.join(os.getcwd(), args.log_dir)
    os.makedirs(log_dir, exist_ok=True)
    save_path = os.path.join(log_dir, args.save_file)

    if args.cuda == "auto":
        device = "auto"
    else:
        device = torch.device(int(args.cuda))
    tokenizer, model = get_model(args.model_path, device, method=args.method, args=args)

    model.eval()

    token_speed_gpu = defaultdict(dict)

    # Warm-up run so one-off start-up costs are not attributed to the first
    # measured prompt length.
    warmup_text = " ".join(["test"] * 1000)
    warmup_inputs = tokenizer(warmup_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        model.generate(warmup_inputs.input_ids, **model_custom_config)

    for token_count in range(args.start, args.end, args.step):
        # Release cached GPU memory, wait for the device to go idle, then
        # pause briefly before taking the next measurement.
        torch.cuda.empty_cache()
        _sync_cuda(model.device)
        time.sleep(5)

        input_text = " ".join(["test"] * (token_count - 1))
        inputs_token = tokenizer(input_text, return_tensors="pt").to(model.device)

        # Actual prompt length after tokenization; used as the result key.
        token_length = len(inputs_token.input_ids[0])

        start_gpu_util = get_gpu_memory_usage()
        # Make sure the host-to-device copy above is not counted in the
        # timed region.
        _sync_cuda(model.device)
        start_time = time.perf_counter()
        with torch.no_grad():
            model.generate(inputs_token.input_ids, **model_custom_config)
        # Wait for the queued kernels to finish before stopping the clock;
        # otherwise only the launch overhead might be measured.
        _sync_cuda(model.device)
        end_time = time.perf_counter()
        end_gpu_util = get_gpu_memory_usage()

        dict_value = {
            "time": round(end_time - start_time, 2),
            "start_gpu_util": start_gpu_util,
            "end_gpu_util": end_gpu_util
        }
        token_speed_gpu[token_length] = dict_value
        print("token count: {}, dict-value: {}".format(token_count, dict_value))

        # Checkpoint after every measurement so an interrupted or
        # out-of-memory run still leaves usable results on disk.
        with open(save_path, "wb") as f:
            pickle.dump({"token_speed_gpu": token_speed_gpu}, f)


if __name__ == "__main__":
    # Command-line options as (flag, type, default), registered in this order.
    cli_options = (
        ("--model_path", str, "/data/persist/models/llama-3b"),
        ("--method", str, "streaming-llm"),
        ("--save_file", str, "speed-gpu-streaming-llm-next-token_test.pkl"),
        ("--log_dir", str, "../logs"),
        ("--cuda", str, "0"),
        ("--hard_cuda", int, 1),
        ("--seed", int, 0),
        ("--start", int, 1024),
        ("--step", int, 1024),
        ("--end", int, 32*1024),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()

    # Seed first, then run the benchmark with the parsed options.
    set_seed(args.seed)
    main(args)

















