import ast
import json
import os
import sys

from openai import OpenAI, OpenAIError
from tqdm import tqdm
from wrench.dataset import load_dataset

def calculate_cost(input_tok, output_tok, model="gpt-3.5-turbo"):
    """Estimate the cost of one request from its token counts.

    Args:
        input_tok: Number of prompt (input) tokens.
        output_tok: Number of completion (output) tokens.
        model: Name of the model whose rates should be applied.

    Returns:
        The combined prompt + completion cost, rounded to 6 decimal places.

    Raises:
        ValueError: If ``model`` has no entry in the rate table.
    """
    # (prompt rate, completion rate) per 1,000 tokens for each known model.
    # NOTE(review): currency is presumably USD -- confirm against the
    # providers' pricing pages.
    rates_per_thousand = {
        "gpt-3.5-turbo": (0.0005, 0.0015),
        "gpt-4": (0.03, 0.06),
        "claude-2.1": (0.008, 0.024),
        "claude-3-sonnet-20240229": (0.003, 0.015),
    }

    rates = rates_per_thousand.get(model)
    if rates is None:
        raise ValueError("Invalid model specified")

    prompt_rate, completion_rate = rates
    total = input_tok * prompt_rate / 1000 + output_tok * completion_rate / 1000
    return round(total, 6)
        
def gpt_inference(system_prompt, user_prompt, model="gpt-3.5-turbo"):
    """Send one chat-completion request and map the answer to a class id.

    The model's reply is expected to be a dict literal containing a
    ``'class'`` key whose value is a key of the module-level
    ``class2id_dict``. The API key is read from the module-level
    ``OPENAI_API_KEY``.

    Args:
        system_prompt: Content of the ``system`` message.
        user_prompt: Content of the ``user`` message.
        model: Chat model name; it must also be known to ``calculate_cost``.

    Returns:
        A tuple ``(input_tok, output_tok, est_cost, prediction)``.
        ``prediction`` is ``-1`` when the reply cannot be parsed or mapped to
        a class id. When the API call itself fails the tuple is
        ``(None, None, None, -1)``.
    """
    # Only the network call can raise OpenAIError, so keep the guarded
    # region limited to it.
    try:
        client = OpenAI(api_key=OPENAI_API_KEY)
        completion = client.chat.completions.create(
            model=model,
            temperature=0.7,
            max_tokens=100,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
    except OpenAIError as e:
        print(f"OpenAI Error: {e}")
        return None, None, None, -1

    input_tok = completion.usage.prompt_tokens
    output_tok = completion.usage.completion_tokens
    est_cost = calculate_cost(input_tok=input_tok, output_tok=output_tok, model=model)

    response = completion.choices[0].message.content

    try:
        # The reply is untrusted text: parse it as a literal only.
        # ast.literal_eval never executes code, unlike eval().
        pred_dict = ast.literal_eval(response)
        prediction = pred_dict.get('class')
        final_prediction = int(class2id_dict[prediction])
    except Exception as e:
        # Deliberate best-effort: any malformed or unexpected reply is
        # recorded as -1 so a long batch run is not aborted by one bad answer.
        print(f"Error: {e}")
        return input_tok, output_tok, est_cost, -1

    return input_tok, output_tok, est_cost, final_prediction

# Command-line arguments: <dataset_name> <data_type>. data_type selects the
# split: "train" or "valid"; any other value falls back to the test split.
if len(sys.argv) < 3:
    sys.exit(f"Usage: {sys.argv[0]} <dataset_name> <train|valid|test>")

dataset_name = sys.argv[1]
data_type = sys.argv[2]

# Root directory that holds one sub-directory per dataset.
DATA_ROOT = "/hdd1/AutoPWS_Data/"

### Load Data ###
train_data, valid_data, test_data = load_dataset(
    DATA_ROOT,
    dataset_name,
    extract_feature=True,
    extract_fn='bert',
    model_name='bert-base-cased',
    cache_name='bert'
)

if data_type == "train":
    data = train_data
elif data_type == "valid":
    data = valid_data
else:
    # Normalise unknown values so the output file name reflects the split
    # that is actually used.
    data_type = "test"
    data = test_data

# label.json maps class id -> class name; the context manager guarantees the
# file is closed even if parsing fails.
with open(os.path.join(DATA_ROOT, dataset_name, "label.json")) as f:
    id2class_dict = json.load(f)

# Inverse mapping (class name -> class id) used to decode model answers.
class2id_dict = {v: k for k, v in id2class_dict.items()}

### Prompting ###
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

with open("./prompt_template.json", "r") as f:
    prompts = json.load(f)[dataset_name]

system_prompt = prompts["prompting_system_prompt"]
mission_template = prompts["prompting_mission_statement"]

# Create the output directory up front so a missing directory cannot make the
# first checkpoint fail after many API calls have already been made.
output_dir = "./pure_prompting_prediction"
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"{dataset_name}_{data_type}.json")

prediction_dict = {}


def _save_predictions(path, predictions):
    """Write all predictions collected so far to ``path`` as JSON."""
    with open(path, "w") as outfile:
        json.dump(predictions, outfile)


for i, example in enumerate(tqdm(data.examples, total=len(data.examples))):
    label = data.labels[i]
    text = example["text"]
    user_prompt = mission_template.replace("[text]", text)

    input_tok, output_tok, cost, prediction = gpt_inference(system_prompt=system_prompt, user_prompt=user_prompt, model="gpt-3.5-turbo")

    # gpt_inference returns (None, None, None, -1) when the API call itself
    # failed; such examples are left out of the results.
    request_failed = input_tok is None and output_tok is None and cost is None
    if not request_failed:
        prediction_dict[str(i)] = {
            "label": label,
            "text": text,
            "input_tok": input_tok,
            "output_tok": output_tok,
            "cost": cost,
            "prediction": prediction
        }

    # Periodic checkpoint; taken even when this particular request failed.
    if i > 0 and (i % 100) == 0:
        _save_predictions(output_path, prediction_dict)

# Final save: without it, every prediction after the last multiple of 100
# (or all of them for small datasets) would be lost.
_save_predictions(output_path, prediction_dict)

