import json
import random
from tqdm import trange
import utils
from d5_problem import D5Problem as Problem
from typing import List, Union
from dataclasses import dataclass
from dataclasses_json import dataclass_json
from query import query_wrapper


PROPOSER_TEMPLATE_NAMES = ["orig", "detailed"]

# Proposer prompt templates keyed by template name. Each one is read at import
# time from ``templates/gpt_d5_proposer_<name>.txt`` (a path relative to the
# current working directory) and later filled in with ``str.format``.
PROPOSER_TEMPLATE_DICTS = {}
for template_name in PROPOSER_TEMPLATE_NAMES:
    # Decode explicitly as UTF-8 so loading does not depend on the platform's
    # locale encoding.
    with open(
        f"templates/gpt_d5_proposer_{template_name}.txt", "r", encoding="utf-8"
    ) as f:
        PROPOSER_TEMPLATE_DICTS[template_name] = f.read()


def construct_proposer_prompt(
    text_samples_a: List[str],
    text_samples_b: List[str],
    goal: str,
    example_descriptions: List[str],
    num_descriptions_per_prompt: int,
    template_name: str,
) -> str:
    """
    Construct the prompt for the proposer model.

    Each sample is prefixed with its corpus letter and a 0-based index
    (``A0.``, ``A1.``, ..., ``B0.``, ...), the numbered samples of each
    corpus are joined with newlines, and everything is substituted into the
    template named ``template_name``.

    Parameters
    ----------
    text_samples_a : List[str]
        A list of text samples for the first category.
    text_samples_b : List[str]
        A list of text samples for the second category.
    goal : str
        The goal or objective the proposer model should follow.
    example_descriptions : List[str]
        Example descriptions provided as a formatting reference. They are
        lower-cased, quoted, and listed one per line after a short
        introduction. An empty list (or ``None``) omits the section.
    num_descriptions_per_prompt : int
        The number of descriptions the model should suggest.
    template_name : str
        The name of the template to use; must be one of
        ``PROPOSER_TEMPLATE_NAMES``.

    Returns
    -------
    str
        The formatted prompt for the proposer model.

    Raises
    ------
    AssertionError
        If ``template_name`` is not a known template.
    """
    assert template_name in PROPOSER_TEMPLATE_NAMES, (
        f"unknown template {template_name!r}; "
        f"expected one of {PROPOSER_TEMPLATE_NAMES}"
    )

    samples_in_prompt_a = "\n".join(
        f"A{i}. {text}" for i, text in enumerate(text_samples_a)
    )
    samples_in_prompt_b = "\n".join(
        f"B{i}. {text}" for i, text in enumerate(text_samples_b)
    )

    example_description_in_prompt = ""
    if example_descriptions:
        quoted_examples = "\n".join(
            f'"{example_description.lower()}"'
            for example_description in example_descriptions
        )
        # The introductory sentence must precede the examples so the model
        # knows to imitate their format but not their content.
        example_description_in_prompt = (
            "Here are some example hypotheses you have generated; please generate something in the same format but different in content:\n"
            + "\n"
            + quoted_examples
            + "\n"
        )

    template = PROPOSER_TEMPLATE_DICTS[template_name]
    return template.format(
        goal=goal,
        samples_in_prompt_a=samples_in_prompt_a,
        samples_in_prompt_b=samples_in_prompt_b,
        example_description_in_prompt=example_description_in_prompt,
        num_descriptions_per_prompt=num_descriptions_per_prompt,
    )


@dataclass_json
@dataclass
class D5ProposerResponse:
    """
    The response from the proposer model.

    Attributes
    ----------
    descriptions : List[str]
        Descriptions of the difference between the two corpora, parsed from
        ``raw_response``.
    proposer_prompt : str
        The prompt used for the proposer model.
    a_text_subset : List[str]
        The text samples from Corpus A used in the prompt.
    b_text_subset : List[str]
        The text samples from Corpus B used in the prompt.
    raw_response : str
        The raw response from the proposer model, before parsing.
    """

    descriptions: List[str]
    proposer_prompt: str
    a_text_subset: List[str]
    b_text_subset: List[str]
    raw_response: str


def propose_descriptions(
    problem: Problem,
    num_samples: int,
    num_descriptions_per_prompt: int,
    model: str,
    random_seed: int = 0,
    example_descriptions: Union[List[str], None] = None,
    template_name: str = "orig",
) -> Union[D5ProposerResponse, None]:
    """
    Propose descriptions for a given problem.

    Up to ``num_samples`` texts are drawn (without replacement) from each of
    ``problem.texts_a`` and ``problem.texts_b``, formatted into a prompt
    together with ``problem.goal``, sent to ``model``, and the reply is
    parsed into a list of descriptions.

    Parameters
    ----------
    problem : Problem
        The problem instance.
    num_samples : int
        The maximum number of text samples per corpus to include in the
        prompt; capped at the corpus size.
    num_descriptions_per_prompt : int
        The number of descriptions the model should suggest.
    model : str
        The model to use for proposing descriptions.
    random_seed : int
        The seed for sampling text samples. Note that this reseeds the
        module-global ``random`` generator.
    example_descriptions : List[str], optional
        Example descriptions provided for formatting reference. Defaults to
        no examples.
    template_name : str
        The name of the template to use for the prompt, by default "orig".

    Returns
    -------
    D5ProposerResponse or None
        The response from the proposer model, including the descriptions,
        the prompt, and the text samples used in the prompt. ``None`` if the
        model query returned no response.
    """
    # Avoid a shared mutable default; None means "no examples".
    if example_descriptions is None:
        example_descriptions = []

    # NOTE: this seeds the process-wide RNG, so it also affects any later
    # ``random`` calls made by the caller.
    random.seed(random_seed)

    # Sample without replacement, capping at the corpus size so a small
    # corpus does not make random.sample raise ValueError.
    goal = problem.goal
    text_samples_a = random.sample(
        problem.texts_a, min(num_samples, len(problem.texts_a))
    )
    text_samples_b = random.sample(
        problem.texts_b, min(num_samples, len(problem.texts_b))
    )

    proposer_prompt = construct_proposer_prompt(
        text_samples_a=text_samples_a,
        text_samples_b=text_samples_b,
        goal=goal,
        example_descriptions=example_descriptions,
        num_descriptions_per_prompt=num_descriptions_per_prompt,
        template_name=template_name,
    )

    # query_wrapper takes a batch of prompts; a None entry signals that no
    # response was obtained for that prompt.
    raw_response = query_wrapper([proposer_prompt], model=model)[0]
    if raw_response is None:
        return None

    # The prompt asks for one quoted description per line; the parsing is
    # delegated to utils.parse_description_responses.
    descriptions = utils.parse_description_responses(raw_response)

    return D5ProposerResponse(
        descriptions=descriptions,
        proposer_prompt=proposer_prompt,
        a_text_subset=text_samples_a,
        b_text_subset=text_samples_b,
        raw_response=raw_response,
    )


if __name__ == "__main__":
    # Manual smoke run: load a saved problem, ask GPT-4 for descriptions using
    # the "detailed" template, and print them.
    problem_path = "data/hh_1.json"
    with open(problem_path, "r", encoding="utf-8") as f:
        problem = Problem.from_json(f.read())

    proposer_response = propose_descriptions(
        problem=problem,
        num_samples=20,
        random_seed=42,
        example_descriptions=problem.example_descriptions,
        num_descriptions_per_prompt=5,
        model="gpt-4",
        template_name="detailed",
    )

    # propose_descriptions returns None when the model query yields no
    # response; report that instead of crashing on attribute access.
    if proposer_response is None:
        raise SystemExit("No response received from the proposer model.")

    print(proposer_response.descriptions)
