import os
import json
from typing import List, Dict, Union
from tqdm import tqdm, trange
import time
import numpy as np
import scipy
import scipy.optimize
from transformers import GPT2Tokenizer
from copy import deepcopy
import openai
import random
import os
import asyncio
import anthropic
from copy import deepcopy
from typing import List, Dict, Iterator
import concurrent.futures

# Not referenced anywhere else in this module.
# NOTE(review): possibly read by modules that import this one — confirm before removing.
MAX_LIMIT = 5
# Number of attempts made per Claude request in query_claude_once.
MAX_RETRIES = 20
# Default number of prompts sent per Completion request by
# gpt3wrapper_texts_batch_iter.
BATCH_SIZE = 20

# Import-time side effect: the OpenAI key is read from the environment, so
# importing this module raises KeyError when OPENAI_API_KEY is unset.
openai.api_key = os.environ["OPENAI_API_KEY"]
# GPT-2 tokenizer, loaded once at import time; not referenced in this module.
# NOTE(review): presumably kept for token counting by importers (e.g. before
# estimate_querying_cost) — confirm. from_pretrained may download/cache files.
GPT2TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2")

# Chat message template: the chat wrappers deep-copy it and put the prompt
# into the user message at index 1. Do not mutate it in place.
DEFAULT_MESSAGE = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": None},
]


def chat_gpt_query(**args) -> Union[None, List[str]]:
    """
    Query openai.ChatCompletion.create(), making up to 10 attempts.

    Parameters
    ----------
    **args
        Keyword arguments forwarded to openai.ChatCompletion.create() (model,
        temperature, n, ...). If ``messages`` is absent or None, ``prompt`` is
        required: it is wrapped into the system + user pair from
        ``DEFAULT_MESSAGE`` and removed from the request.

    Returns
    -------
    Union[None, List[str]]
        The message content of every returned choice, or None if all 10
        attempts raised an exception.

    Raises
    ------
    KeyError
        If neither ``messages`` nor ``prompt`` is supplied.
    """
    if args.get("messages") is None:
        # ``args`` is a fresh dict built from the keyword arguments, so editing
        # it does not affect the caller. The template is deep-copied so the
        # shared DEFAULT_MESSAGE is never mutated.
        args["messages"] = deepcopy(DEFAULT_MESSAGE)
        args["messages"][1]["content"] = args.pop("prompt")

    # Every model (gpt-4 included) is billed to the same organization.
    openai.organization = os.environ["SUBSIDIZED_ORG"]

    max_attempts = 10
    for attempt in range(max_attempts):
        try:
            responses = openai.ChatCompletion.create(**args)
            return [choice.message.content for choice in responses.choices]
        except Exception as e:
            # KeyboardInterrupt derives from BaseException, so it is not caught
            # here and propagates with its original traceback.
            print(e)
            if attempt < max_attempts - 1:
                time.sleep(10)

    return None


def estimate_querying_cost(
    num_prompt_toks: int, num_completion_toks: int, model: str
) -> float:
    """
    Estimate the cost of one API call using list prices as of 2023-04-06.

    Parameters
    ----------
    num_prompt_toks : int
        Number of prompt (input) tokens.
    num_completion_toks : int
        Number of completion (output) tokens.
    model : str
        Model name. Exact names are looked up in a price table; any name that
        starts with "text-davinci-" uses the davinci price.

    Returns
    -------
    float
        The estimated cost (per-1K-token list price divided by 1000, times tokens).

    Raises
    ------
    ValueError
        If the model has no known price.
    """
    # model -> (price per prompt token, price per completion token)
    per_token_prices = {
        "gpt-3.5-turbo": (0.002 / 1000, 0.002 / 1000),
        "gpt-4": (0.03 / 1000, 0.06 / 1000),
        "gpt-4-32k": (0.06 / 1000, 0.12 / 1000),
    }

    if model in per_token_prices:
        prompt_rate, completion_rate = per_token_prices[model]
    elif model.startswith("text-davinci-"):
        prompt_rate = completion_rate = 0.02 / 1000
    else:
        raise ValueError(f"Unknown model: {model}")

    return num_prompt_toks * prompt_rate + num_completion_toks * completion_rate


def gpt3wrapper(max_repeat=20, **arguments) -> Union[None, openai.Completion]:
    """
    Call openai.Completion.create(), making up to ``max_repeat`` attempts.

    Parameters
    ----------
    max_repeat : int, optional
        Maximum number of attempts, by default 20. Every failed attempt except
        the last is followed by a 30-second sleep.
    **arguments
        Keyword arguments forwarded to openai.Completion.create() (engine,
        max_tokens, temperature, ...). The prompt is given as ``prompts`` (a
        string or a list of strings) and renamed to the API's ``prompt``
        parameter. Passing ``prompt`` directly is also accepted.

    Returns
    -------
    Union[None, openai.Completion]
        The raw API response, or None if every attempt raised an exception.

    Raises
    ------
    KeyError
        If neither ``prompts`` nor ``prompt`` is supplied.
    """
    openai.organization = os.environ["SUBSIDIZED_ORG"]

    # Build the request once, outside the retry loop, so that every retry sends
    # the same, complete request. Renaming inside the loop would run against an
    # already-renamed copy and fail before reaching the API.
    request = dict(arguments)
    if "prompts" in request:
        request["prompt"] = request.pop("prompts")
    if "prompt" not in request:
        raise KeyError("prompts")

    for attempt in range(max_repeat):
        try:
            return openai.Completion.create(**request)
        except Exception as e:
            # KeyboardInterrupt derives from BaseException, so it is not caught
            # here and still aborts the retry loop.
            print(request["prompt"])
            print(e)
            if attempt < max_repeat - 1:
                print("now sleeping")
                time.sleep(30)
    return None


def gpt3wrapper_texts(max_repeat=20, **arguments):
    """
    Query the Completion API through ``gpt3wrapper`` and return only the text.

    Parameters
    ----------
    max_repeat : int, optional
        Maximum number of API attempts, forwarded to ``gpt3wrapper``.
    **arguments
        Forwarded to ``gpt3wrapper``; must include ``prompts``.

    Returns
    -------
    Union[None, str, List[str]]
        When ``prompts`` is a list, the text of every choice ordered by the
        choice ``index`` field. Otherwise, the text of the first choice.
        None if the API call failed.
    """
    response = gpt3wrapper(max_repeat=max_repeat, **arguments)
    if response is None:
        return None
    if isinstance(arguments["prompts"], list):
        # For batched prompts the API does not guarantee that choices come back
        # in prompt order; each choice's ``index`` maps it back to its prompt.
        choices = sorted(response["choices"], key=lambda choice: choice["index"])
        return [choice["text"] for choice in choices]
    return response["choices"][0]["text"]


def gpt3wrapper_texts_batch(max_repeat=20, bsize=20, verbose=False, **arguments):
    """
    Route a Completion query to the batched or the single-prompt helper.

    A list in ``prompts`` is split into batches of ``bsize`` and all texts are
    collected eagerly into a list. A single string is passed straight to
    ``gpt3wrapper_texts``. Any other prompt type fails an assertion.
    """
    prompts = arguments["prompts"]

    if type(prompts) != list:
        assert type(prompts) == str
        return gpt3wrapper_texts(max_repeat=max_repeat, **arguments)

    text_stream = gpt3wrapper_texts_batch_iter(
        max_repeat=max_repeat, bsize=bsize, verbose=verbose, **arguments
    )
    return list(text_stream)


def gpt3wrapper_texts_batch_iter(
    max_repeat=20, bsize=BATCH_SIZE, verbose=False, **arguments
):
    """
    Lazily query the Completion API over ``prompts`` in batches of ``bsize``.

    Parameters
    ----------
    max_repeat : int, optional
        Maximum number of API attempts per batch, forwarded to ``gpt3wrapper``.
    bsize : int, optional
        Number of prompts sent per API request, by default ``BATCH_SIZE``.
    verbose : bool, optional
        Show a tqdm progress bar over batches.
    **arguments
        Forwarded to ``gpt3wrapper``; ``prompts`` must be a list of strings.

    Yields
    ------
    Union[None, str]
        The generated texts, batch by batch, each batch ordered by the choice
        ``index``. When a batch still fails after all retries, one None is
        yielded per expected completion (``len(batch) * n``), so later results
        keep their positions.
    """
    prompts = arguments["prompts"]
    assert isinstance(prompts, list)

    # The API returns ``n`` completions per prompt (1 when not given).
    completions_per_prompt = arguments.get("n") or 1
    num_batches = (len(prompts) - 1) // bsize + 1
    iterator = trange(num_batches) if verbose else range(num_batches)
    for i in iterator:
        batch = prompts[i * bsize : (i + 1) * bsize]
        # A shallow copy is enough: only the ``prompts`` entry is replaced, and
        # a deep copy would re-copy the full prompt list for every batch.
        batch_arguments = dict(arguments, prompts=batch)
        response = gpt3wrapper(max_repeat=max_repeat, **batch_arguments)
        if response is None:
            # Size the placeholders from the batch itself (the request dict
            # has no ``prompt`` key at this point).
            for _ in range(len(batch) * completions_per_prompt):
                yield None
        else:
            # Batched choices are not guaranteed to be in prompt order.
            choices = sorted(response["choices"], key=lambda choice: choice["index"])
            for choice in choices:
                yield choice["text"]


def gpt3_query(max_repeat=20, **arguments) -> Union[None, str, List[str]]:
    """
    Return only the generated text from openai.Completion.create().

    Parameters
    ----------
    max_repeat : int, optional
        Maximum number of API attempts, by default 20.
    **arguments
        Forwarded to the Completion wrappers; must include ``prompts`` (a
        string or a list of strings) plus engine, temperature, etc.

    Returns
    -------
    Union[None, str, List[str]]
        A list of texts when ``prompts`` is a list (queried in batches),
        otherwise the first choice's text. None if a single-prompt call fails.
    """
    prompts = arguments["prompts"]
    if type(prompts) == list:
        # Lists are handed to the batching helper.
        return gpt3wrapper_texts_batch(max_repeat=max_repeat, **arguments)

    response = gpt3wrapper(max_repeat=max_repeat, **arguments)
    return None if response is None else response["choices"][0]["text"]


def get_context_length(model: str) -> int:
    """
    Look up the context-window size (in tokens) used for an API model.

    Parameters
    ----------
    model : str
        Exact model name, e.g. "gpt-4" or "text-davinci-003".

    Returns
    -------
    int
        The context length assumed for that model.

    Raises
    ------
    ValueError
        If the model name is not in the table.
    """
    # NOTE(review): the published windows are 8,192 (gpt-4), 32,768 (gpt-4-32k)
    # and 4,097 (text-davinci-002/003); the rounded-down values here may be a
    # deliberate safety margin — confirm before changing.
    known_lengths = {
        "text-davinci-002": 4096,
        "text-davinci-003": 4096,
        "gpt-3.5-turbo": 4096,
        "gpt-4": 8000,
        "gpt-4-32k": 32000,
    }
    if model not in known_lengths:
        raise ValueError(f"Unknown model {model}")
    return known_lengths[model]


async def query_claude_once(client, **args) -> Dict[str, str]:
    """
    Request one Claude completion, making up to MAX_RETRIES attempts.

    Parameters
    ----------
    client
        An anthropic client exposing ``acompletion`` (created by the caller).
    **args
        Must contain ``prompt``, ``model``, ``max_tokens_to_sample``,
        ``temperature``, ``top_p`` and ``TMP_ID``.

    Returns
    -------
    Dict[str, str]
        The API response with ``TMP_ID`` added so callers can restore input
        order. If every attempt fails, ``{"completion": "", "TMP_ID": ...}``.

    Raises
    ------
    KeyError
        If a required key is missing. It is raised before any request is made
        and is not retried.
    """
    # Build the request up front, so that a missing key fails immediately
    # instead of being retried as if it were a transient API error.
    request = dict(
        prompt=f"{anthropic.HUMAN_PROMPT} {args['prompt']}{anthropic.AI_PROMPT}",
        stop_sequences=[anthropic.HUMAN_PROMPT],
        model=args["model"],
        max_tokens_to_sample=args["max_tokens_to_sample"],
        temperature=args["temperature"],
        top_p=args["top_p"],
    )
    tmp_id = args["TMP_ID"]

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            resp = await client.acompletion(**request)
            resp["TMP_ID"] = tmp_id
            return resp
        except Exception as e:
            print(f"Error: {e}. Retrying {attempt}/{MAX_RETRIES}")
            if attempt < MAX_RETRIES:
                # Exponential backoff (1, 2, 4, ... capped at 30 s), so errors
                # such as rate limits are not retried in a tight loop.
                await asyncio.sleep(min(2 ** (attempt - 1), 30))

    return {"completion": "", "TMP_ID": tmp_id}


# Baseline Claude request parameters. run_parallel_async_texts_iterator copies
# this dict and overlays any caller-supplied keys for each request.
DEFAULT_CLAUDE_PARAMETER_DICT = dict(
    model="claude-v1.3",
    max_tokens_to_sample=100,
    temperature=0.7,
    top_p=1.0,
)


async def run_parallel_async_texts_iterator(
    progress_bar=False, max_concurrent=5, **args
) -> Iterator[dict]:
    """
    Query Claude for every prompt concurrently, yielding results as they finish.

    Parameters
    ----------
    progress_bar : bool, optional
        Show a tqdm progress bar (labelled "validating").
    max_concurrent : int, optional
        Maximum number of requests in flight at once.
    **args
        Must contain ``prompts`` (a list of strings). Every other key overrides
        the matching DEFAULT_CLAUDE_PARAMETER_DICT entry for each request.

    Yields
    ------
    dict
        One result per prompt, in completion order rather than input order.
        Each result carries ``TMP_ID``, the index of its prompt in ``prompts``.

    Notes
    -----
    This is an async generator. The ``Iterator[dict]`` annotation is kept for
    compatibility; ``AsyncIterator[dict]`` is the accurate type.
    """
    shared = {key: value for key, value in args.items() if key != "prompts"}
    # Defaults hold only immutable values, so a shallow merge is sufficient.
    args_list = [
        {**DEFAULT_CLAUDE_PARAMETER_DICT, **shared, "prompt": prompt, "TMP_ID": tmp_id}
        for tmp_id, prompt in enumerate(args["prompts"])
    ]

    semaphore = asyncio.Semaphore(max_concurrent)
    tasks = [run_query_with_semaphore(semaphore, **arg) for arg in args_list]
    pbar = tqdm(total=len(args_list), desc="validating") if progress_bar else None
    try:
        for future in asyncio.as_completed(tasks):
            result = await future
            if pbar is not None:
                pbar.update(1)
            yield result
    finally:
        # Close the bar even if the consumer stops early or a query raises.
        if pbar is not None:
            pbar.close()


async def run_query_with_semaphore(semaphore: asyncio.Semaphore, **arg) -> dict:
    """
    Run a single ``query_claude_once`` call while holding ``semaphore``.

    A fresh anthropic client is built from ANTHROPIC_API_KEY for every call,
    after the semaphore has been acquired.
    """
    await semaphore.acquire()
    try:
        client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])
        result = await query_claude_once(client, **arg)
        return result
    finally:
        semaphore.release()


async def run_parallel_async_dicts_batch(**args) -> List[dict]:
    """Drain ``run_parallel_async_texts_iterator`` into a list (completion order)."""
    collected = []
    async for item in run_parallel_async_texts_iterator(**args):
        collected.append(item)
    return collected


def run_parallel_sync_dicts_batch(**args) -> List[dict]:
    """
    Run the concurrent Claude queries synchronously, returning results in prompt order.

    Uses ``asyncio.run``, so this cannot be called from inside an already
    running event loop. Results are ordered by their ``TMP_ID`` (prompt index).
    """
    unordered = asyncio.run(run_parallel_async_dicts_batch(**args))
    return sorted(unordered, key=lambda item: item["TMP_ID"])


def claude_query(**args) -> List[str]:
    """
    Return Claude completion texts, one per entry of ``args["prompts"]``, in prompt order.

    A prompt whose request failed on every retry contributes an empty string.
    """
    ordered_results = run_parallel_sync_dicts_batch(**args)
    texts = []
    for item in ordered_results:
        texts.append(item["completion"])
    return texts


# DEFAULT_MESSAGE (the system + user chat template used by chat_gpt_query and
# single_chat_gpt_wrapper) is defined once, near the top of this module.


def single_chat_gpt_wrapper(args) -> Union[None, str]:
    """
    Return the first choice's text from openai.ChatCompletion.create(), making up to 10 attempts.

    Takes a single dict, which is how ``chat_gpt_wrapper_parallel`` submits it.

    Parameters
    ----------
    args : dict
        Keyword arguments for openai.ChatCompletion.create(). If ``messages``
        is absent or None, ``prompt`` is required: it is wrapped into the
        system + user pair from DEFAULT_MESSAGE. The dict itself is not modified.

    Returns
    -------
    Union[None, str]
        Content of the first choice, or None if all 10 attempts raised.

    Raises
    ------
    KeyError
        If neither ``messages`` nor ``prompt`` is supplied.
    """
    # Work on a copy so the caller's dict does not lose ``prompt`` or gain
    # ``messages`` as a side effect.
    request = dict(args)
    if request.get("messages") is None:
        request["messages"] = deepcopy(DEFAULT_MESSAGE)
        request["messages"][1]["content"] = request.pop("prompt")

    openai.organization = os.environ["SUBSIDIZED_ORG"]

    max_attempts = 10
    for attempt in range(max_attempts):
        try:
            response = openai.ChatCompletion.create(**request)
            return response.choices[0].message.content
        except Exception as e:
            # KeyboardInterrupt derives from BaseException, so it is not caught
            # here and propagates with its original traceback.
            print(e)
            if attempt < max_attempts - 1:
                time.sleep(10)

    return None


def chat_gpt_wrapper_parallel(
    prompts: List[str], num_processes: int = 1, progress_bar: bool = True, **args
) -> List[Union[None, str]]:
    """
    Query the chat API once per prompt, optionally across worker processes.

    Parameters
    ----------
    prompts : List[str]
        Prompts to send. Each becomes the user message of one request.
    num_processes : int, optional
        1 runs the requests sequentially in this process. Any other value is
        used as ``max_workers`` of a ProcessPoolExecutor.
    progress_bar : bool, optional
        Show a tqdm progress bar that ticks as each request finishes.
    **args
        Extra keyword arguments for openai.ChatCompletion.create() (model,
        max_tokens, temperature, ...), shared by every request.

    Returns
    -------
    List[Union[None, str]]
        One response text per prompt, in prompt order. An entry is None when
        that request failed on every attempt.
    """
    pbar = tqdm(total=len(prompts), desc="Processing") if progress_bar else None
    try:
        if num_processes == 1:
            results = []
            for prompt in prompts:
                results.append(single_chat_gpt_wrapper({**args, "prompt": prompt}))
                if pbar is not None:
                    pbar.update(1)
            return results

        with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor:
            futures = [
                executor.submit(single_chat_gpt_wrapper, {**args, "prompt": prompt})
                for prompt in prompts
            ]
            # Tick the bar as requests finish, in whatever order they finish...
            for _ in concurrent.futures.as_completed(futures):
                if pbar is not None:
                    pbar.update(1)
            # ...but return the results in prompt order.
            return [future.result() for future in futures]
    finally:
        # Close the bar even if a request raised.
        if pbar is not None:
            pbar.close()


def query_wrapper(
    prompts: List[str],
    model: str = "text-davinci-003",
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 1.0,
    num_processes: int = 1,
) -> List[str]:
    """
    Query any supported LLM API with a list of prompts.

    Parameters
    ----------
    prompts : List[str]
        Prompts to query the model with.
    model : str, optional
        Model name, routed by prefix: ``claude*`` -> Anthropic, ``gpt*`` ->
        OpenAI chat API, ``text-davinci*`` -> OpenAI completion API.
        By default "text-davinci-003".
    max_tokens : int, optional
        Maximum number of tokens to generate, by default 512.
    temperature : float, optional
        Temperature for sampling, by default 0.7.
    top_p : float, optional
        Top p for sampling, by default 1.0.
    num_processes : int, optional
        Degree of parallelism, by default 1. For Claude models it is the
        maximum number of concurrent requests; for chat models, the number of
        worker processes. It is ignored for text-davinci models, which are
        sent in batches instead.

    Returns
    -------
    List[str]
        One generated text per prompt. Failed queries may produce "" (Claude)
        or None (OpenAI) entries.

    Raises
    ------
    AssertionError
        If ``prompts`` is not a list.
    ValueError
        If ``model`` matches none of the supported prefixes.
    """
    assert isinstance(prompts, list)
    args = {"temperature": temperature, "top_p": top_p, "prompts": prompts}

    if model.startswith("claude"):
        args["max_tokens_to_sample"] = max_tokens
        args["model"] = model
        args["max_concurrent"] = num_processes
        return claude_query(**args)
    if model.startswith("gpt"):
        args["model"] = model
        args["max_tokens"] = max_tokens
        args["num_processes"] = num_processes
        return chat_gpt_wrapper_parallel(**args)
    if model.startswith("text-davinci"):
        args["engine"] = model
        args["max_tokens"] = max_tokens
        return gpt3_query(**args)

    # Fail loudly rather than returning None, matching how get_context_length
    # and estimate_querying_cost treat unknown models.
    raise ValueError(f"Unknown model: {model}")


if __name__ == "__main__":
    # Smoke test: ask each listed model the same yes/no question 30 times.
    # Other models previously exercised here: "gpt-3.5-turbo", "claude-v1.3".
    demo_prompts = ["Obama is a male. Yes or no?"] * 30
    for model_name in ("text-davinci-003",):
        print(model_name)
        completions = query_wrapper(
            prompts=demo_prompts,
            model=model_name,
            temperature=1.0,
            top_p=1.0,
            max_tokens=10,
            num_processes=5,
        )
        print(completions)
