import time
import os
from typing import Any, Dict, List

import openai
import logging

from llm_interface.large_language_model import LargeLanguageModel
from prompt_compiler.data_structs.llm_response import LLMResponse

# Module-level logger, looked up under the fixed name "global_logger" rather
# than this module's ``__name__``.
logger = logging.getLogger("global_logger")

class GPT(LargeLanguageModel):
    """Completion-style GPT model queried through the ``openai`` package.

    Requests are sent with ``openai.Completion.create``. Depending on
    ``use_azure`` the ``openai`` module is configured for a hard-coded Azure
    endpoint or is only given the plain OpenAI key. Rate-limit, connection and
    API errors are retried a fixed number of times before giving up.
    """

    # Credentials are read from the environment once, when the class body is
    # executed; either value is ``None`` if its variable is unset.
    AZURE_API_KEY = os.environ.get("AZURE_KEY")
    OPENAI_API_KEY = os.environ.get("OPENAI_KEY")

    # Retry policy shared by every request issued by this class.
    _MAX_ATTEMPTS = 6
    _RETRY_WAIT_SECONDS = 6

    def __init__(self, model_name: str, use_azure=True) -> None:
        """Store the model name and configure the ``openai`` module.

        Args:
            model_name: Value sent as ``engine`` on every completion request.
            use_azure: When true, point the ``openai`` module at the Azure
                endpoint and use ``AZURE_API_KEY``; otherwise only install
                ``OPENAI_API_KEY``.
        """
        # NOTE(review): the base-class initializer is not invoked here;
        # confirm LargeLanguageModel does not require it.
        self._model_name = model_name
        self.use_azure = use_azure
        # These are attributes of the ``openai`` module itself, so the
        # configuration is process-wide: the most recently constructed
        # instance determines where every instance sends its requests.
        if self.use_azure:
            openai.api_key = self.AZURE_API_KEY
            openai.api_base = "https://symdistill.openai.azure.com/"
            openai.api_type = 'azure'
            openai.api_version = '2023-03-15-preview'
        else:
            # NOTE(review): api_base/api_type/api_version are left untouched
            # on this path, so Azure settings installed by an earlier instance
            # would persist; confirm that is acceptable.
            openai.api_key = self.OPENAI_API_KEY

    def get_id(self) -> str:
        """Return the identifier ``"gpt_<model_name>"`` for this model."""
        return f"gpt_{self._model_name}"

    def _create_completion_with_retries(self, **request_kwargs: Any) -> Any:
        """Call ``openai.Completion.create`` and retry on transient errors.

        Args:
            **request_kwargs: Keyword arguments forwarded to
                ``openai.Completion.create`` in addition to ``engine``.

        Returns:
            The response object returned by ``openai.Completion.create``.

        Raises:
            RuntimeError: If every attempt failed with a rate-limit,
                connection or API error. The last such error is attached as
                ``__cause__``.
        """
        last_error = None
        for attempt in range(1, self._MAX_ATTEMPTS + 1):
            try:
                return openai.Completion.create(
                    engine=self._model_name, **request_kwargs)
            except (openai.error.RateLimitError,
                    openai.error.APIConnectionError,
                    openai.error.APIError) as err:
                last_error = err
                # Back off before the next attempt; sleeping after the final
                # attempt would only delay the error below.
                if attempt < self._MAX_ATTEMPTS:
                    time.sleep(self._RETRY_WAIT_SECONDS)
        raise RuntimeError("Failed to query OpenAI API.") from last_error

    def _sample_completions(
            self,
            prompt: str,
            temperature: float,
            stop_token: str,
            max_tokens: int,
            freq_penalty: float,
            num_completions: int = 1) -> List[LLMResponse]:
        """Request ``num_completions`` completions of ``prompt``.

        Args:
            prompt: Text sent as the completion prompt.
            temperature: Sampling temperature forwarded to the API.
            stop_token: Forwarded to the API as ``stop``.
            max_tokens: Forwarded to the API as ``max_tokens``.
            freq_penalty: Forwarded to the API as ``frequency_penalty``.
            num_completions: Forwarded to the API as ``n``.

        Returns:
            One ``LLMResponse`` per entry in the response's ``"choices"``.

        Raises:
            RuntimeError: If the API could not be queried after all retries.
            AssertionError: If the number of returned choices differs from
                ``num_completions``.
        """
        response = self._create_completion_with_retries(
            prompt=prompt,
            temperature=temperature,
            stop=stop_token,
            max_tokens=max_tokens,
            frequency_penalty=freq_penalty,
            n=num_completions)

        choices = response["choices"]
        assert len(choices) == num_completions, (
            f"Expected {num_completions} completions, got {len(choices)}.")
        return [
            self._raw_to_llm_response(choice, prompt, temperature, stop_token,
                                      num_completions)
            for choice in choices
        ]

    def _sample_next_token_with_logit_bias(
            self,
            prompt: str,
            logit_bias: Dict[Any, float],
            temperature: float = 0.0) -> str:
        """Return the text of a short completion sampled under ``logit_bias``.

        Args:
            prompt: Text sent as the completion prompt.
            logit_bias: Forwarded unchanged to the API as ``logit_bias``.
            temperature: Sampling temperature forwarded to the API.

        Returns:
            The ``"text"`` field of the first returned choice.

        Raises:
            RuntimeError: If the API could not be queried after all retries.
        """
        # NOTE(review): max_tokens is 2 although the method name says "next
        # token"; confirm whether a single token was intended.
        response = self._create_completion_with_retries(
            prompt=prompt,
            # Honour the caller's temperature instead of a hard-coded 0.0.
            temperature=temperature,
            max_tokens=2,
            logit_bias=logit_bias)
        return response["choices"][0]["text"]

    @staticmethod
    def _raw_to_llm_response(raw_response: Dict[str, Any],
                             prompt: str,
                             temperature: float,
                             stop_token: str,
                             num_completions: int) -> LLMResponse:
        """Convert one raw API choice into an ``LLMResponse``.

        Args:
            raw_response: A single entry of the response's ``"choices"``; its
                ``"text"`` field is used as the completion text.
            prompt: The prompt that produced this completion.
            temperature: Recorded in the response's ``prompt_info``.
            stop_token: Recorded in the response's ``prompt_info``.
            num_completions: Recorded in the response's ``prompt_info``.

        Returns:
            An ``LLMResponse`` holding the cleaned text, the sampling settings
            and a shallow copy of ``raw_response``.
        """
        text = raw_response["text"]

        # Remove the chat markers first and strip afterwards, so whitespace
        # that was adjacent to a removed marker does not survive.
        for marker in ("<|im_end|>", "<|im_sep|>"):
            text = text.replace(marker, "")
        text = text.strip()

        prompt_info = {
            "temperature": temperature,
            "num_completions": num_completions,
            "stop_token": stop_token,
        }
        return LLMResponse(prompt,
                           text,
                           prompt_info=prompt_info,
                           other_info=raw_response.copy())
