from fastchat.model import load_model, get_conversation_template, add_model_args
import torch
import openai
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
from vllm import LLM
from vllm.transformers_utils.tokenizer import get_tokenizer


@torch.inference_mode()
def create_model(args, model_path):
    """Load ``model_path`` through FastChat's ``load_model``.

    Args:
        args: Options object. Reads ``device.type``, ``gpus``,
            ``max_gpu_memory``, ``load_8bit``, ``cpu_offloading``,
            ``revision`` and ``debug``.
        model_path: Model identifier passed as the first argument to
            ``load_model``.

    Returns:
        The ``(model, tokenizer)`` pair produced by ``load_model``.
    """
    # Any device type other than "cuda" is loaded on the CPU.
    if args.device.type == "cuda":
        device_name = "cuda"
    else:
        device_name = "cpu"

    loaded_model, loaded_tokenizer = load_model(
        model_path,
        device_name,
        args.gpus,
        args.max_gpu_memory,
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        revision=args.revision,
        debug=args.debug,
    )
    return loaded_model, loaded_tokenizer


def create_model_and_tok(args, model_path, target=False):
    """Create the model handle and tokenizer for ``model_path``.

    OpenAI models are not loaded locally: ``openai.api_key`` is set from
    ``args.openai_key`` and the model name itself is returned with a ``None``
    tokenizer. Open-source models are loaded either with vLLM or through
    ``create_model``, depending on the model name and ``target``.

    Args:
        args: Options object. ``openai_key`` is read for OpenAI models; an
            optional ``download_dir`` attribute is forwarded to vLLM; the
            object is passed unchanged to ``create_model`` otherwise.
        model_path: Name of the model; must be one of the supported names
            listed in this function.
        target: When true, models whose name contains ``'7b'`` or
            ``'mistral'`` are also loaded with vLLM.

    Returns:
        A ``(model, tokenizer)`` tuple. For OpenAI models this is
        ``(model_path, None)``.

    Raises:
        SystemExit: With status 1 when ``model_path`` is not supported, or
            when an OpenAI model is requested while ``args.openai_key`` still
            holds the placeholder text.
    """
    # Note that 'moderation' is only used for classification and cannot be used for generation
    openai_model_list = ['gpt-3.5-turbo-1106','gpt-3.5-turbo-0613', 'gpt-3.5-turbo', 'gpt-3.5-turbo-0301', 'gpt-4-0613', 'gpt-4', 'gpt-4-0301', 'moderation']
    open_sourced_model_list = ['tiiuae/falcon-7b-instruct', 'tiiuae/falcon-40b', 'lmsys/vicuna-7b-v1.3', 'lmsys/vicuna-33b-v1.3', 'meta-llama/Llama-2-7b-chat-hf', 'lmsys/vicuna-13b-v1.3', 'THUDM/chatglm2-6b', 'meta-llama/Llama-2-13b-chat-hf', 'meta-llama/Llama-2-70b-chat-hf','baichuan-inc/Baichuan-13B-Chat', "mistralai/Mixtral-8x7B-Instruct-v0.1"]
    supported_model_list = openai_model_list + open_sourced_model_list
    if model_path not in supported_model_list:
        print("Please provide a valid model name in the list: {}".format(supported_model_list))
        # Raise SystemExit directly: the `exit` builtin is injected by `site`
        # and reports success (status 0) even though this is a failure.
        raise SystemExit(1)

    if model_path in openai_model_list:
        if args.openai_key == 'You must have an OpenAI key':
            print("Please provide your OpenAI key or choose an open-sourced model")
            raise SystemExit(1)
        print(f"use {model_path}")
        openai.api_key = args.openai_key
        # OpenAI models are addressed by name; there is no local tokenizer.
        return model_path, None

    # Everything that reaches this point is in open_sourced_model_list.
    use_vllm = (
        '70b' in model_path
        or 'falcon' in model_path
        or 'vicuna' in model_path
        or (target and ('7b' in model_path or 'mistral' in model_path))
    )
    if not use_vllm:
        model, tokenizer = create_model(args, model_path)
        return model, tokenizer

    # Shard across every visible GPU, except falcon-7b which is pinned to one.
    parallel_size = torch.cuda.device_count()
    if 'falcon-7b' in model_path:
        parallel_size = 1
    model_args = {
        "model": model_path,
        "gpu_memory_utilization": 0.9,
        # Optional override; None leaves the download location to vLLM.
        "download_dir": getattr(args, "download_dir", None),
        "revision": None,
        "dtype": 'float16',
        "tokenizer": None,
        "tokenizer_mode": 'auto',
        "tokenizer_revision": None,
        "trust_remote_code": False,
        "tensor_parallel_size": parallel_size,
        "swap_space": 4,
        "quantization": None,
        "seed": 1234,
    }
    model = LLM(**model_args)
    tokenizer = get_tokenizer(
        model_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        tokenizer_revision=None,
    )
    return model, tokenizer


def prepare_model_and_tok(args, target=False):
    """Create model(s) and tokenizer(s) for ``args.model_path``.

    Args:
        args: Options object. ``model_path`` is either a single model name
            (``str``) or a ``list`` of model names; the object is passed
            unchanged to ``create_model_and_tok``.
        target: Forwarded to ``create_model_and_tok`` when ``model_path`` is
            a single name.

    Returns:
        ``(model, tokenizer)`` for a single name, or a pair of parallel lists
        ``(models, tokenizers)`` when ``model_path`` is a list.

    Raises:
        NotImplementedError: If ``model_path`` is neither a ``str`` nor a
            ``list``.
    """
    model_path = args.model_path

    if isinstance(model_path, str):
        model, tok = create_model_and_tok(args, model_path, target=target)
        return model, tok

    if isinstance(model_path, list):
        models, tokenizers = [], []
        for path in model_path:
            # NOTE(review): `target` is not forwarded for list inputs, so
            # each entry is created with the default target=False. Confirm
            # whether this is intentional.
            model, tok = create_model_and_tok(args, path)
            models.append(model)
            tokenizers.append(tok)
        return models, tokenizers

    raise NotImplementedError(
        "args.model_path must be a str or a list, got {}".format(
            type(model_path).__name__
        )
    )
