import os
import re
import json
import logging

from prompt_compiler.earley_parser.parser import EarleyParser
from prompt_compiler.lark_utils import collect_rules_from_larkfile, collect_rules_from_larkstr
from llm_interface.gpt import GPT
from llm_interface.chatgpt import ChatGPT
from prompt_compiler.data_structs.example import Example

logger = logging.getLogger("global_logger")

def load_parser(dataset, use_action_flag_list):
    """Build the Earley parser for ``dataset`` and collect its grammar rules.

    The grammar is read from ``grammars/{dataset}.lark``. ``GeoQuery``,
    ``SMCalFlow`` and ``Overnight-Blk`` each use their own start symbol and
    are always parsed straight from the grammar file. Every other dataset
    starts at ``program``; when ``use_action_flag_list`` is true, its
    ``action_name: ESCAPED_STRING`` rule is replaced by an explicit
    alternation of the action names loaded for the dataset's subject (the
    part of ``dataset`` before the first ``_``).

    Args:
        dataset: Dataset name; selects the grammar file and start symbol.
        use_action_flag_list: Whether to restrict ``action_name`` to the
            known action names (only applies to the ``program`` grammars).

    Returns:
        A ``(parser, rules)`` tuple: the ``EarleyParser`` and the first item
        returned by the rule collector for the same grammar source.

    Raises:
        ValueError: If ``grammars/{dataset}.lark`` does not exist.
        KeyError: If the action list has no entry for the dataset's subject.
    """
    grammar_file = f"grammars/{dataset}.lark"
    if not os.path.exists(grammar_file):
        raise ValueError(f"dataset {dataset} not supported")

    # Datasets with a dedicated start symbol; everything else uses "program".
    start_symbols = {
        "GeoQuery": "query",
        "SMCalFlow": "call",
        "Overnight-Blk": "list_value",
    }

    # Stays None unless the grammar text is rewritten in memory. The rules are
    # then collected from the same source the parser was built from; this
    # avoids referencing an unassigned ``grammar_str`` when the flag is set
    # for one of the datasets that is always loaded from the file.
    grammar_str = None
    if dataset in start_symbols or not use_action_flag_list:
        global_parser = EarleyParser.open(
            grammar_file,
            start=start_symbols.get(dataset, "program"),
            keep_all_tokens=True,
        )
    else:
        action_list = load_action_list("data/dsl_tag_result/")[dataset.split("_")[0]]
        with open(grammar_file, encoding="utf-8") as f:
            grammar_str = f.read()
        # Each action becomes the grammar literal "\"<action>\"", i.e. a
        # terminal that matches the action name wrapped in double quotes.
        alternatives = " | ".join(f'"\\"{a}\\""' for a in action_list)
        grammar_str = grammar_str.replace(
            "action_name: ESCAPED_STRING", "action_name: " + alternatives
        )
        global_parser = EarleyParser.open_by_str(
            grammar_str, start="program", keep_all_tokens=True
        )

    if grammar_str is None:
        global_rules, _ = collect_rules_from_larkfile(grammar_file)
    else:
        global_rules, _ = collect_rules_from_larkstr(grammar_str)
    return global_parser, global_rules

def load_llm(engine):
    """Instantiate the LLM client named by ``engine``.

    Args:
        engine: Identifier of the form ``"<platform>/<engine>"``; it is split
            at the first ``/``.

    Returns:
        ``GPT(engine, use_azure=False)`` for ``openai/code-davinci-002``,
        otherwise ``ChatGPT(engine)`` for any other ``openai`` engine.

    Raises:
        ValueError: If ``engine`` contains no ``/``.
        NotImplementedError: If the platform is not ``openai``.
    """
    platform, separator, engine_short = engine.partition("/")
    if not separator:
        # Same exception type str.index raised, but with an actionable message.
        raise ValueError(
            f"engine {engine!r} must have the form '<platform>/<engine>'"
        )
    if platform != "openai":
        raise NotImplementedError(f"platform {platform} not supported")
    if engine_short == "code-davinci-002":
        return GPT(engine_short, use_azure=False)
    return ChatGPT(engine_short)

def load_examples(filename):
    """Read parallel source/target files into a list of ``Example`` objects.

    Args:
        filename: Two paths joined by a single comma,
            ``"<source_path>,<target_path>"``.

    Returns:
        One ``Example`` per line pair, with surrounding whitespace stripped
        from both sides and in file order.

    Raises:
        ValueError: If ``filename`` does not consist of exactly two
            comma-separated paths.
    """
    paths = filename.split(",")
    if len(paths) != 2:
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        raise ValueError(
            f"expected '<source_path>,<target_path>', got {filename!r}"
        )
    src_filename, trg_filename = paths
    with open(src_filename, encoding="utf-8") as f1, \
            open(trg_filename, encoding="utf-8") as f2:
        # zip stops at the shorter file, so unpaired trailing lines are ignored.
        return [
            Example(source=line1.strip(), target=line2.strip())
            for line1, line2 in zip(f1, f2)
        ]

def load_train_data(dataset):
    """Load the training examples stored under ``data/{dataset}/``.

    Args:
        dataset: Dataset directory name below ``data/``.

    Returns:
        Whatever ``load_examples`` returns for the ``train.src`` /
        ``train.tgt`` pair of that dataset.
    """
    return load_examples(f"data/{dataset}/train.src,data/{dataset}/train.tgt")

def load_test_data(quickrun, cut_len):
    """Load ``data/testset.json`` and split each procedure into sentence cuts.

    Each procedure's ``"procedures"`` text has its whitespace collapsed, is
    split into sentences on ``". "`` and is grouped into consecutive chunks
    of at most ``cut_len`` sentences. Every chunk becomes one entry.

    Args:
        quickrun: If true, only the first entry is returned.
        cut_len: Maximum number of sentences per cut.

    Returns:
        A list of dicts with keys ``"cut"`` (an ``Example`` whose source is
        the chunk's sentences joined by ``". "`` plus a trailing ``". "``, and
        whose target is empty), ``"bigAreas"``, ``"bigProb"`` and
        ``"smallProb"`` (copied from the procedure), ``"procedure_id"`` and
        ``"cut_id"``.
    """
    with open("data/testset.json", "r", encoding="utf-8") as f:
        test_data = json.load(f)

    test_examples = []
    for procedure_id, procedure in enumerate(test_data):
        # \s+ also matches newlines, so one substitution flattens the text.
        text = re.sub(r"\s+", " ", procedure["procedures"]).strip()
        # Each cut gets ". " appended below, so drop the text's own final
        # period; otherwise the last cut would end in ".. ".
        if text.endswith("."):
            text = text[:-1]
        # Skip empty pieces so no cut consists of punctuation only.
        sentences = [s for s in text.split(". ") if s]
        for cut_id, start in enumerate(range(0, len(sentences), cut_len)):
            chunk = sentences[start:start + cut_len]
            test_examples.append({
                "cut": Example(source=". ".join(chunk) + ". ", target=""),
                "bigAreas": procedure["bigAreas"],
                "bigProb": procedure["bigProb"],
                "smallProb": procedure["smallProb"],
                "procedure_id": procedure_id,
                "cut_id": cut_id
            })
    if quickrun:
        test_examples = test_examples[:1]
    return test_examples

def load_action_list(path):
    """Load the action names of every subject area from its DSL JSON file.

    Args:
        path: Directory containing the per-subject JSON files. A trailing
            path separator is accepted but not required.

    Returns:
        A dict mapping each subject (``"Genetics"``, ``"Medical"``,
        ``"Ecology"``, ``"BioEng"``) to the sorted list of lower-cased
        entries obtained by iterating that subject's parsed JSON document.
    """
    # NOTE(review): "environmental" appears twice in the Ecology file name;
    # confirm it matches the file on disk before changing it.
    subject_list = {
        "Genetics": "molecular_biology_and_genetics_dsl.json",
        "Medical": "biomedical_and_clinical_research_dsl.json",
        "Ecology": "ecology_and_environmental_environmental_dsl.json",
        "BioEng": "bioengineering_and_technology_dsl.json"
    }
    action_list = {}
    for subject, filename in subject_list.items():
        # os.path.join works whether or not ``path`` ends with a separator.
        with open(os.path.join(path, filename), "r", encoding="utf-8") as f:
            data = json.load(f)
        action_list[subject] = sorted(a.lower() for a in data)
    return action_list