import argparse
import random
import os
import json
import time

import tqdm

from utils.config import Config
from utils.logger import create_logger, display_exp_setting
from utils.loader import load_data
from utils.traj import get_trags
from symbolic_compiler.compiler import SymbolicCompiler

def _build_parser():
    """Create the command-line parser for this script.

    Every option is registered in the same order, with the same type and
    default, as the flat top-level definition this replaces.
    """
    cli = argparse.ArgumentParser()

    # run control / bookkeeping
    cli.add_argument("--mode", default="testset")
    cli.add_argument("--log_dir", default="exp")
    cli.add_argument("--seed", type=int, default=0)
    cli.add_argument("--quickrun", action="store_true")

    # data and DSL selection
    cli.add_argument("--dataset", type=str, default="testset")
    cli.add_argument("--dsl", type=str, default="autodsl")

    # alpha / beta / gamma coefficients
    cli.add_argument("--alpha", type=float, default=0.05)
    cli.add_argument("--beta", type=float, default=0.45)
    cli.add_argument("--gamma", type=float, default=0.50)

    # llm
    cli.add_argument("--engine", type=str, default="openai/gpt-3.5-turbo")
    cli.add_argument("--temperature", type=float, default=0.0)
    cli.add_argument("--freq_penalty", type=float, default=0.0)
    cli.add_argument("--max_tokens", type=int, default=2048)
    cli.add_argument("--llm_cache_dir", type=str, default="llm_cache")

    return cli


# Parsed at import time, exactly as before; `parser` and `args` stay
# module-level names.
parser = _build_parser()
args = parser.parse_args()

# Result columns shared by both run modes, in the order they are written to
# results.json.
RESULT_KEYS = (
    "sources",
    "predictions",
    "times",
    "reg_flows",
    "trajs",
    "input_regs",
    "infos",
)

# Areas evaluated, each with its own compiler, in "testset" mode.
TESTSET_AREAS = ("BioEng", "Ecology", "Genetics", "Medical")


def build_compiler(cfg, dataset):
    """Create a SymbolicCompiler for *dataset* from the settings held in *cfg*."""
    return SymbolicCompiler(
        dataset,
        cfg.dsl,
        cfg.engine,
        cfg.temperature,
        cfg.freq_penalty,
        cfg.max_tokens,
        cfg.llm_cache_dir,
        cfg.alpha,
        cfg.beta,
        cfg.gamma,
    )


def compile_timed(compiler, example):
    """Compile one example, echo its prediction, and time the call.

    Args:
        compiler: object exposing ``compile(example)``.
        example: the input handed to ``compiler.compile``.

    Returns:
        ``(outputs, elapsed)`` where ``outputs`` is the list of values
        returned by ``compiler.compile(example)`` and ``elapsed`` is the
        wall-clock duration in seconds as a string. The timed span covers
        both the compile call and the printing of the prediction.
    """
    began = time.time()
    outputs = list(compiler.compile(example))
    print(outputs[0])  # the prediction is the first returned value
    elapsed = time.time() - began
    return outputs, str(elapsed)


def example_info(example):
    """Return the JSON string describing which area/problem *example* belongs to."""
    return json.dumps({"bigAreas": example["bigAreas"], "bigProb": example["bigProb"], "smallProb": example["smallProb"]})


def run_symbolic_compiler(cfg, test_data):
    """Compile every item of *test_data* with one compiler built for ``cfg.dataset``.

    Returns:
        dict mapping each name in ``RESULT_KEYS`` to a list with one entry per
        example ("infos" stays empty in this mode).
    """
    compiler = build_compiler(cfg, cfg.dataset)
    results = {key: [] for key in RESULT_KEYS}

    for example in tqdm.tqdm(test_data, total=len(test_data)):
        outputs, elapsed = compile_timed(compiler, example)
        # "testset" mode unpacks six values from the same compile() call;
        # accept either arity here rather than failing on trailing values.
        prediction, input_reg, reg_flow, traj, *_ = outputs

        results["sources"].append(example)
        results["predictions"].append(prediction)
        results["times"].append(elapsed)
        results["input_regs"].append(input_reg)
        results["reg_flows"].append(reg_flow)
        results["trajs"].append(traj)

    return results


def run_testset(cfg, test_data):
    """Compile the test set area by area, using a dedicated compiler per area.

    Only items whose "bigAreas" equals one of ``TESTSET_AREAS`` are processed;
    the "procedures" field of each item is what gets compiled.

    Returns:
        dict with the ``RESULT_KEYS`` columns plus "compare_sents" and
        "mappings", one entry per processed example.
    """
    results = {key: [] for key in RESULT_KEYS + ("compare_sents", "mappings")}

    for area in TESTSET_AREAS:
        compiler = build_compiler(cfg, area)
        subset = [item for item in test_data if item["bigAreas"] == area]
        print(area)
        for item in tqdm.tqdm(subset, total=len(subset)):
            results["infos"].append(example_info(item))
            procedures = item["procedures"]
            outputs, elapsed = compile_timed(compiler, procedures)
            prediction, input_reg, reg_flow, traj, compare_sent, mapping = outputs

            results["sources"].append(procedures)
            results["predictions"].append(prediction)
            results["times"].append(elapsed)
            results["input_regs"].append(input_reg)
            results["reg_flows"].append(reg_flow)
            results["trajs"].append(traj)
            results["compare_sents"].append(compare_sent)
            results["mappings"].append(mapping)

    return results


def dump_results(cfg, logger, results):
    """Write results.json and trajs.csv into ``cfg.result_dir``."""
    # Build the trajectory table before dumping the JSON, as the original
    # script did, so the JSON reflects the same state of results["trajs"].
    trajs_info = get_trags(results["trajs"])

    with open(f"{cfg.result_dir}/results.json", "w", encoding="utf-8") as f:
        logger.info(f"dumping results to {cfg.result_dir}/results.json")
        json.dump(results, f, indent=2)

    logger.info(f"dumping trajs to {cfg.result_dir}/trajs.csv")
    trajs_info.to_csv(f"{cfg.result_dir}/trajs.csv", index=False)


def main():
    """Run the experiment selected by ``--mode`` and persist its outputs.

    Raises:
        ValueError: if ``cfg.mode`` is neither "symbolic-compiler" nor
            "testset" (previously such a run silently produced nothing).
    """
    # 1. setup config, logger, llm, dataset
    random.seed(args.seed)
    cfg = Config(args)
    logger = create_logger(os.path.join(cfg.log_dir, 'log.txt'))
    display_exp_setting(logger, cfg)
    test_data = load_data(cfg.dataset, cfg.quickrun)

    # 2. run the requested mode
    if cfg.mode == "symbolic-compiler":
        results = run_symbolic_compiler(cfg, test_data)
    elif cfg.mode == "testset":
        results = run_testset(cfg, test_data)
    else:
        raise ValueError(
            f"unknown mode {cfg.mode!r}; expected 'symbolic-compiler' or 'testset'"
        )

    # 3. persist outputs
    dump_results(cfg, logger, results)


if __name__ == "__main__":
    main()