import torch
import transformers
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from tqdm import tqdm

llama_path = '/data/models/huggingface-format/llama-2-7b-chat/'

# One shared text-generation pipeline, used by call() and batch_call() below.
# Weights are loaded in float16 and placed on the available devices
# automatically (device_map="auto").
pipeline = transformers.pipeline(
    "text-generation",
    model=llama_path,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Batched generation pads its inputs, so the tokenizer is given a pad token
# here by reusing BOS. Padding goes on the left so that every prompt in a
# batch ends exactly where generation starts (decoder-only model).
# NOTE(review): call()/batch_call() pass eos_token_id as pad_token_id to
# generation, so BOS is only used to pad the tokenized prompts.
pipeline.tokenizer.padding_side = 'left'
pipeline.tokenizer.pad_token_id = pipeline.tokenizer.bos_token_id

def call(message, max_tokens=500):
    """Generate one sampled chat completion from the shared pipeline.

    Args:
        message: Either a plain user prompt string, or a conversation given
            as a list of ``{'role': ..., 'content': ...}`` dicts accepted by
            the tokenizer's chat template.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The generated text with the rendered prompt removed. Surrounding
        whitespace is left as-is.
    """
    # The chat template expects a list of message dicts, not a bare string;
    # a plain string is treated as a single user turn.
    if isinstance(message, str):
        message = [{'role': 'user', 'content': message}]

    # Render the conversation to prompt text (not token ids) and append the
    # marker that asks the model to produce the next assistant reply.
    prompt = pipeline.tokenizer.apply_chat_template(
        message,
        tokenize=False,
        add_generation_prompt=True,
    )

    with torch.no_grad():
        sequences = pipeline(
            prompt,
            do_sample=True,
            num_return_sequences=1,
            eos_token_id=pipeline.tokenizer.eos_token_id,
            pad_token_id=pipeline.tokenizer.eos_token_id,
            max_new_tokens=max_tokens,
            temperature=0.8,
        )

    # By default the pipeline returns prompt + completion in 'generated_text';
    # keep only the completion.
    return sequences[0]['generated_text'][len(prompt):]

class TEMP(Dataset):
    """Map-style dataset of prompts pre-wrapped in the Llama-2 ``[INST]`` format.

    Each raw prompt ``p`` is stored as ``'<s>[INST] p [/INST]'`` so that it can
    be fed straight to the text-generation pipeline in batches.
    """

    def __init__(self, prompts):
        # Format every prompt once up front; __getitem__ is then a plain lookup.
        self.prompts = [f'<s>[INST] {text} [/INST]' for text in prompts]

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        # Out-of-range indices raise IndexError from the underlying list.
        return self.prompts[idx]

def batch_call(prompts, max_tokens=300):
    """Generate one sampled completion per prompt, 16 prompts per batch.

    Args:
        prompts: Iterable of raw user prompt strings; each is wrapped in the
            Llama-2 ``[INST]`` format by ``TEMP``.
        max_tokens: Maximum number of new tokens generated per prompt.

    Returns:
        A list of completions, paired positionally with the inputs, each with
        the formatted prompt removed and surrounding whitespace stripped.
    """
    dataset = TEMP(prompts)
    generation_kwargs = dict(
        batch_size=16,
        do_sample=True,
        num_return_sequences=1,
        eos_token_id=pipeline.tokenizer.eos_token_id,
        pad_token_id=pipeline.tokenizer.eos_token_id,
        max_new_tokens=max_tokens,
        temperature=0.8,
    )
    stream = pipeline(dataset, **generation_kwargs)

    # Each output echoes its formatted prompt before the completion; slice it
    # off. The tqdm wrapper shows progress as results are produced.
    paired = tqdm(zip(stream, dataset), total=len(dataset))
    return [
        result[0]['generated_text'][len(formatted):].strip()
        for result, formatted in paired
    ]

if __name__ == '__main__':
    # Smoke test: one chat-style call and one small batch.
    # call() renders its argument with the tokenizer's chat template, which
    # expects a list of {'role', 'content'} messages rather than a bare string.
    print('call:\n', call([{'role': 'user', 'content': 'hello'}]))
    print('batch_call:\n', batch_call(['hello', 'Mello']))

