import argparse
import json
import os
import random
import re
import time

import numpy as np
import openai
import pandas as pd
import torch
from tqdm import tqdm
# Prefer the GPU, but fall back to the CPU so the module still loads on
# machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# API credentials. They are read from the environment so that keys do not have
# to be committed with the source; an empty string means "not configured".
OPENAI_API = os.environ.get("OPENAI_API_KEY", "")  # or put the OpenAI API key here
BARD_API = os.environ.get("GOOGLE_API_KEY", "")  # or put the Bard/PaLM API key here

def parse_args(argv=None):
    """
    Parse the command-line arguments of the filtering script.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``experiment``,
        ``optimize``, ``length_fix``, ``frames``, ``filters`` (list of ints,
        or None when the option is omitted) and ``model``.
    """
    parser = argparse.ArgumentParser(
        description="Running experiments"
    )
    parser.add_argument(
        "--d",
        dest="dataset",
        help="which dataset to use",
        default="",
        type=str,
    )
    parser.add_argument(
        "--e",
        dest="experiment",
        help="experiment name",
        default="",
        type=str,
    )
    parser.add_argument(
        "--w",
        dest="optimize",
        help="optimize run",
        default="",
        type=str,
    )
    parser.add_argument(
        "--l",
        dest="length_fix",
        help="fix_folder",
        default="",
        type=str,
    )
    # The entry point reads ``args.frames`` to build the features folder and
    # the result file name, so the option has to exist on the namespace.
    parser.add_argument(
        "--frames",
        dest="frames",
        help="frame-count tag used in the features folder and result file names",
        default="",
        type=str,
    )
    parser.add_argument(
        '--filters',
        nargs="+",
        type=int,
        help="which filters to use, *** MUST ADD IN ORDER ***"
    )
    parser.add_argument(
        "--model",
        choices=['gpt-4', 'gpt-3.5-turbo', 'bard'],
        default="gpt-4",
        help="name of the model",
        type=str,
    )
    return parser.parse_args(argv)


"""
FILTER 1
"""
def filter_qaw(clip_id, question, answer, wrong_answers, clip_result, **kwargs):
    """
    Reject questions with missing/blank text or containing banned keywords.

    Returns ``(keep, extra)`` where ``keep`` is False when the question, the
    answer or the list of wrong answers is None, when any of the texts is
    blank, or when one of them contains a banned substring (compared in
    lower case). ``extra`` is always an empty dict. ``clip_id``,
    ``clip_result`` and ``kwargs`` are accepted but not used.
    """
    banned_in_question = ("narration", "timestamp", "how much", "how many", "frequency")
    banned_in_answer = ("narration", "timestamp", "unclear", "unknown", "not specified")
    banned_in_wrong = ("narration", "timestamp")

    # Anything missing altogether cannot be checked further
    if None in (question, answer, wrong_answers):
        return False, {}

    if not question.strip() or not answer.strip():
        return False, {}

    stripped_wrong = [candidate.strip() for candidate in wrong_answers]
    if "" in stripped_wrong:
        return False, {}

    question_lc = question.lower()
    if any(term in question_lc for term in banned_in_question):
        return False, {}

    answer_lc = answer.lower()
    if any(term in answer_lc for term in banned_in_answer):
        return False, {}

    wrong_lc = [candidate.lower() for candidate in wrong_answers]
    has_banned_wrong = any(
        term in candidate for term in banned_in_wrong for candidate in wrong_lc
    )
    return (not has_banned_wrong), {}
     

def no_clip_gpt(clip_id, question, answer, wrong_answers, clip_result,
                *, max_trials=3, retry_delay=1, **kwargs):
    """
    Ask the chat model to answer the question without any clip context.

    The prompt is the module-level ``gpt_clip_prompt`` followed by the
    question and five options: the four wrong answers as options A-D and the
    correct answer as option E. The request goes to ``openai.ChatCompletion``
    with the module-level ``model_name``.

    Args:
        clip_id: accepted for the common filter signature; not used here.
        question: question text.
        answer: correct answer, always presented as option E.
        wrong_answers: sequence whose first four items become options A-D.
        clip_result: dict of earlier results; a stored ``no_clip_gpt_pred``
            is reused instead of calling the API again.
        max_trials: number of API attempts before giving up.
        retry_delay: seconds to sleep after each attempt.
        **kwargs: ignored extra arguments (e.g. ``qa_id``).

    Returns:
        ``(keep, {"no_clip_gpt_pred": pred})``. ``pred`` is the index of the
        chosen option (0-4) or -1 when no attempt produced a usable reply.
        ``keep`` is False only when the model picked option E, the correct
        answer.
    """
    if "no_clip_gpt_pred" in clip_result:
        cached_pred = clip_result["no_clip_gpt_pred"]
        return cached_pred != 4, {"no_clip_gpt_pred": cached_pred}

    for _ in range(max_trials):
        try:
            eval_prompt = (
                f"{gpt_clip_prompt}\n"
                f"Question: {question}?\n\n"
                f"Option A: {wrong_answers[0]}.\n"
                f"Option B: {wrong_answers[1]}.\n"
                f"Option C: {wrong_answers[2]}.\n"
                f"Option D: {wrong_answers[3]}.\n"
                f"Option E: {answer}.\n"
            )

            output = openai.ChatCompletion.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": eval_prompt}
                ]
            )
            output_text = output["choices"][0]["message"]["content"]
        except Exception as e:
            # Best effort: any failure while building or sending the request
            # is reported and the attempt is retried.
            print(e)
            time.sleep(retry_delay)
            print("something wrong with server")
            continue

        # Accept "A", "A.", "Option A", ... but require the letter to stand
        # alone, so replies such as "Based on ..." are not read as option B.
        match = re.match(r"\s*(?:Option\s+)?([A-E])(?![A-Za-z])", output_text or "")
        time.sleep(retry_delay)
        if match:
            pred = "ABCDE".index(match.group(1))
            return pred != 4, {"no_clip_gpt_pred": pred}

    return True, {"no_clip_gpt_pred": -1}

def no_clip_bard(clip_id, question, answer, wrong_answers, clip_result,
                 *, max_trials=3, retry_delay=1, **kwargs):
    """
    Ask a model to answer the question without any clip context, using the
    Bard prompt, and record the result under ``no_clip_bard_pred``.

    The prompt is the module-level ``bard_clip_prompt`` followed by the
    question and five options: the four wrong answers as options A-D and the
    correct answer as option E.

    NOTE(review): despite its name, this filter sends the prompt to
    ``openai.ChatCompletion`` with the module-level ``model_name``; confirm
    whether a Bard/PaLM call was intended.

    Args:
        clip_id: accepted for the common filter signature; not used here.
        question: question text.
        answer: correct answer, always presented as option E.
        wrong_answers: sequence whose first four items become options A-D.
        clip_result: dict of earlier results; a stored ``no_clip_bard_pred``
            is reused instead of calling the API again.
        max_trials: number of API attempts before giving up.
        retry_delay: seconds to sleep after each attempt.
        **kwargs: ignored extra arguments (e.g. ``qa_id``).

    Returns:
        ``(keep, {"no_clip_bard_pred": pred})``. ``pred`` is the index of the
        chosen option (0-4) or -1 when no attempt produced a usable reply.
        ``keep`` is False only when the model picked option E, the correct
        answer.
    """
    if "no_clip_bard_pred" in clip_result:
        cached_pred = clip_result["no_clip_bard_pred"]
        return cached_pred != 4, {"no_clip_bard_pred": cached_pred}

    for _ in range(max_trials):
        try:
            eval_prompt = (
                f"{bard_clip_prompt}\n"
                f"Question: {question}?\n\n"
                f"Option A: {wrong_answers[0]}.\n"
                f"Option B: {wrong_answers[1]}.\n"
                f"Option C: {wrong_answers[2]}.\n"
                f"Option D: {wrong_answers[3]}.\n"
                f"Option E: {answer}.\n"
            )

            output = openai.ChatCompletion.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": eval_prompt}
                ]
            )
            output_text = output["choices"][0]["message"]["content"]
        except Exception as e:
            # Best effort: any failure while building or sending the request
            # is reported and the attempt is retried.
            print(e)
            time.sleep(retry_delay)
            print("something wrong with server")
            continue

        # Accept "A", "A.", "Option A", ... but require the letter to stand
        # alone, so replies such as "Based on ..." are not read as option B.
        match = re.match(r"\s*(?:Option\s+)?([A-E])(?![A-Za-z])", output_text or "")
        time.sleep(retry_delay)
        if match:
            pred = "ABCDE".index(match.group(1))
            # Stored under the Bard key so the cache lookup above can find it
            # and the GPT filter's prediction is not overwritten.
            return pred != 4, {"no_clip_bard_pred": pred}

    return True, {"no_clip_bard_pred": -1}


"""
LIST OF ALL FILTERS CORRESPONDING TO TABLE ABOVE
"""
# Registry of the available filters, keyed by the integer id given via --filters.
# Each value is a 3-tuple that main() unpacks as (filter_fn, true_filter, filter_name):
#   filter_fn    - callable returning (passed, extra_fields_for_the_result_record)
#   true_filter  - when True, a question that fails this filter stops there and is
#                  recorded with this round; the id is also appended to the result file name
#   filter_name  - label written into the "filter_round" field
FILTERS = {
    1: (filter_qaw, True, "Bad Keywords"),
    2: (no_clip_gpt, True, "No clip GPT"),
    3: (no_clip_bard, True, "No clip Bard"),
    }

def main():
    """
    Apply the requested filters to every question and checkpoint the results.

    Relies on module-level state prepared by the ``__main__`` block: ``args``
    (for ``args.filters``), ``qa_data``, ``result_folder`` and ``frame_count``.

    For each clip in ``qa_data`` and each of its three questions, the filters
    are run in the order given on the command line. A question stops at the
    first failing filter whose FILTERS entry is marked as a true filter and
    records that round in ``filter_round``; questions that pass every filter
    get ``filter_round = -1``. The accumulated result list is rewritten to
    ``<result_folder>/<result_file_name>.json`` after every question, and
    records found in an existing file of that name are handed back to the
    filters as ``clip_result``. A per-round pass count is printed at the end.

    Raises:
        ValueError: if ``args.filters`` contains an id missing from FILTERS.
    """
    # FILTERS has no entry for id 0, so the rounds are exactly the requested ids
    rounds = list(args.filters or [])
    unknown_ids = [filter_id for filter_id in rounds if filter_id not in FILTERS]
    if unknown_ids:
        raise ValueError(
            f"unknown filter ids {unknown_ids}; valid ids are {sorted(FILTERS)}"
        )

    # Per round: [questions that passed, questions that reached the filter]
    total_all = [[0, 0] for _ in rounds]

    result = []

    result_file_name = f"{frame_count}_filter_accuracies"
    for filter_id in rounds:
        if FILTERS[filter_id][1]:
            result_file_name += f"_{filter_id}"
    result_path = f"{result_folder}/{result_file_name}.json"

    # Records from a previous run (a list, as written at the end of the loop below)
    filter_results = []
    if os.path.isfile(result_path):
        with open(result_path) as filter_results_f:
            filter_results = json.load(filter_results_f)
    print(len(filter_results))

    clip_id_to_res = {f"{clip['clip_id']}_{clip['qa_i']}": clip for clip in filter_results}

    # Iterate through all clips
    for clip_id in tqdm(qa_data):
        clip = qa_data[clip_id]

        clip_q = clip["q"] if "q" in clip else [None, None, None]
        clip_a = clip["a"] if "a" in clip else [None, None, None]
        clip_w = clip["w"] if "w" in clip else [None, None, None]
        # NOTE(review): the per-question quality labels were read from an
        # undefined name; assumed here to live on the clip under "good" —
        # TODO confirm the real source.
        good_labels = clip.get("good")

        # Iterate through all three questions per clip
        for i in range(3):
            clip_result = clip_id_to_res.get(f"{clip_id}_{i}", {})

            question = clip_q[i]
            answer = clip_a[i]
            wrong_answers = clip_w[i]

            clip_result["clip_id"] = clip_id
            if "clip_url" in clip:
                clip_result["clip_url"] = clip["clip_url"]
            clip_result["qa_i"] = i
            clip_result["q"] = question
            clip_result["a"] = answer
            clip_result["w"] = wrong_answers
            if good_labels is not None:
                clip_result["good"] = (good_labels[i] == "good")
            passed_everything = True

            # Iterate through each filter for a question
            for j, filter_id in enumerate(rounds):
                filter_fn, true_filter, filter_name = FILTERS[filter_id]
                filter_result, filter_extra = filter_fn(
                    clip_id, question, answer, wrong_answers, clip_result, qa_id=i
                )
                clip_result.update(filter_extra)

                # Every question that reaches the filter is counted
                total_all[j][1] += 1
                if filter_result:
                    total_all[j][0] += 1
                elif true_filter:
                    clip_result["filter_round"] = f"Round {j + 1}, Filter {filter_name}"
                    result.append(clip_result)
                    passed_everything = False
                    break

            # If question passed all filters set the round in which it was filtered out to -1
            if passed_everything:
                clip_result["filter_round"] = -1
                result.append(clip_result)

            # Checkpoint after every question so finished work survives an interruption
            with open(result_path, 'w') as f:
                json.dump(result, f)

    for j, filter_id in enumerate(rounds):
        passed, seen = total_all[j]
        print(f"Round {j + 1}, Filter {FILTERS[filter_id][2]}: {passed}/{seen} passed")
    
if __name__ == "__main__":
    # Script entry point: builds the module-level state (args, paths, loaded
    # QA data, API configuration and prompts) that main() and the filters read.
    args = parse_args()

    # The parser may not define a frames option; fall back to an empty tag
    # instead of failing with AttributeError.
    frame_count = getattr(args, "frames", "")
    features_folder = f"./features/{args.dataset}_{frame_count}"

    if args.optimize != "":
        experiment_path = f"{args.dataset}_results/{args.experiment}/stage_2/{args.optimize}"
    else:
        experiment_path = f"{args.dataset}_results/{args.experiment}"

    qa_data_path = f"{experiment_path}/all_results.json"
    result_folder = f"{experiment_path}/filter_results"

    # Load the input first so a wrong experiment path fails before any
    # directory is created.
    with open(qa_data_path) as qa_data_f:
        qa_data = json.load(qa_data_f)

    os.makedirs(result_folder, exist_ok=True)

    # NOTE(review): the --model option is parsed but not used here; the model
    # name stays fixed. Confirm whether args.model should be honored.
    model_name = "gpt-4"
    openai.api_key = OPENAI_API

    # `palm` is not imported anywhere in this file, so an unconditional call
    # raises NameError. Configure it only once the client module has been
    # brought into scope.
    if "palm" in globals():
        palm.configure(api_key=BARD_API)

    with open("prompts/filtering/gpt_no_clip.txt") as gpt_clip_f:
        gpt_clip_prompt = gpt_clip_f.read()

    with open("prompts/filtering/bard_no_clip.txt") as bard_clip_f:
        bard_clip_prompt = bard_clip_f.read()
    main()
