import os
import sys
import copy
import json
import numpy as np
from tqdm import tqdm
from openai import OpenAI
from wrench.dataset import load_dataset
from sklearn.metrics import accuracy_score, f1_score
from wrench.endmodel import EndClassifierModel, LogRegModel

# Name of the dataset to evaluate, taken from the first command-line argument.
# Fail with a usage message instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <dataset_name>")
dataset_name = sys.argv[1]

# End models to evaluate, keyed by the name shown in the printed results.
EM_dict = {
    "MLP": EndClassifierModel(backbone='MLP')
}

# Prompting predictions for the training split: a JSON object whose values
# carry a "text" and a "prediction" field (that is how the rest of this script
# reads them). Decode explicitly as UTF-8 so the result does not depend on the
# platform's default locale encoding.
with open(f'./pure_prompting_prediction/{dataset_name}_train.json', encoding='utf-8') as f:
    prediction_data = json.load(f)

# Load the train / validation / test splits of the dataset from the local data
# root, requesting 'bert-base-cased' feature extraction with cache name 'bert'.
train_data, valid_data, test_data = load_dataset(
    "/hdd1/AutoPWS_Data/",
    dataset_name,
    extract_feature=True,
    extract_fn='bert',
    model_name='bert-base-cased',
    cache_name='bert'
)

# Build a training set whose labels are the prompting predictions rather than
# the original labels. Work on a copy so train_data itself stays usable.
new_train_data = copy.copy(train_data)
# copy.copy is shallow: without this, new_train_data.labels would be the very
# same object as train_data.labels and the overwrites below would silently
# change train_data as well.
new_train_data.labels = copy.copy(train_data.labels)

## add pretended weak labels ##
# Every example starts with a single placeholder weak label of -1; examples
# that receive a prediction below get that prediction instead.
# NOTE(review): -1 is presumably the "abstain" value that get_covered_subset()
# uses to decide which examples are uncovered -- confirm against wrench.
new_train_data.weak_labels = [[-1] for _ in range(len(new_train_data))]

# Index the training examples by their text once, so each prediction is matched
# with a dictionary lookup instead of a scan over the whole training set.
# Identical texts map to several indices; all of them receive the prediction.
text_to_indices = {}
for idx, example in enumerate(new_train_data.examples):
    text_to_indices.setdefault(example["text"], []).append(idx)

for small_dict in prediction_data.values():
    matching_indices = text_to_indices.get(small_dict["text"], ())
    for idx in matching_indices:
        # The prediction becomes both the (only) weak label and the label
        # used for training.
        new_train_data.weak_labels[idx] = [small_dict["prediction"]]
        new_train_data.labels[idx] = small_dict["prediction"]
    if not matching_indices:
        # The predicted text does not occur in the training split.
        print("uncovered")

# Keep only the covered examples, as selected by get_covered_subset().
covered_train_data = new_train_data.get_covered_subset()

# Test-set accuracy and F1, one entry per (repetition, end model) pair.
acc_list = []
f1_list = []

# Binary tasks are scored with binary F1, all others with weighted F1. The
# metric is chosen once, before the loop, so em_f1 is always bound: the former
# if/elif chain matched neither branch when n_class < 2, which left em_f1
# undefined on the first run and stale on later ones.
f1_metric = 'f1_binary' if covered_train_data.n_class == 2 else 'f1_weighted'

# Number of independent train/evaluate repetitions to average over.
n_runs = 5

for i in range(n_runs):
    for end_model_name, end_model in EM_dict.items():
        # Train on the pseudo-labelled subset, then score on the test split.
        # NOTE(review): no validation set is passed to fit(), although
        # valid_data is loaded and evaluation_step / metric are set -- confirm
        # whether validation-based model selection was intended here.
        end_model.fit(dataset_train=covered_train_data, y_train=covered_train_data.labels,
                      evaluation_step=10, metric='acc', verbose=False, device="cuda:0")
        em_acc = end_model.test(test_data, 'acc')
        em_f1 = end_model.test(test_data, f1_metric)
        print(f"{i} - {dataset_name} - {end_model_name} - Acc: {em_acc}, F1: {em_f1}")
        acc_list.append(em_acc)
        f1_list.append(em_f1)

# Aggregate over all collected runs: mean and (population) standard deviation.
acc_array = np.array(acc_list)
f1_array = np.array(f1_list)

acc_mean = np.mean(acc_array, axis=0)
f1_mean = np.mean(f1_array, axis=0)
acc_std = np.std(acc_array, axis=0)
f1_std = np.std(f1_array, axis=0)

print(f"Avg Acc: {acc_mean}, Acc Std: {acc_std}, Avg F1: {f1_mean}, F1 Std: {f1_std}")





      