import json
import nltk
import spacy
import gensim
import logging
import random
import re

from itertools import combinations

from nltk.tokenize import sent_tokenize
from nltk.stem import WordNetLemmatizer
from typing import Any, Dict, List

from symbolic_compiler.utils.loader import load_llm
from symbolic_compiler.data_structs.sentence import Sentence
from symbolic_compiler.data_structs.autodsl_sentence import AutoDslSentence
from symbolic_compiler.data_structs.parameter import Parameter
from symbolic_compiler.DSL.autodsl import AutoDsl
from symbolic_compiler.NER.ner import NER
from symbolic_compiler.utils.utils import similar_action

# Fetch NLTK's Punkt model, which ``sent_tokenize`` relies on (executed at import time).
nltk.download("punkt")

# Module logger; this file logs under the shared "global_logger" name rather than __name__.
logger = logging.getLogger(name="global_logger")

class SymbolicCompiler:
    """Compile free-text protocol descriptions into AutoDSL sentences.

    Pipeline: split the text into sentences, map each sentence to a DSL action
    plus NER parameters, then run a backtracking search over every
    (format, NER-subset) candidate so that consecutive sentences agree on the
    reagents they produce and consume (see ``check_before``).
    """

    def __init__(self, dataset_name, dsl_name, engine, temperature, freq_penalty, max_tokens, llm_cache_dir, alpha, beta, gamma):
        """Load models, the DSL definition and the LLM prompt templates.

        Args:
            dataset_name: dataset id; the part before the first "_" selects
                ``data/<dsl_name>/<prefix>.json``.
            dsl_name: DSL flavour; only ``"autodsl"`` is supported.
            engine: LLM engine name passed to ``load_llm``.
            temperature, freq_penalty, max_tokens, llm_cache_dir: LLM sampling
                settings forwarded to NER and to the consistency checks.
            alpha, beta, gamma: weights of the example-similarity,
                format-coverage and NER-coverage terms of a candidate's score.

        Raises:
            ValueError: if ``dsl_name`` is not ``"autodsl"``.
        """
        if dsl_name != 'autodsl':
            # Without a DSL definition self.dsl would stay unset and NER
            # construction would fail with an obscure AttributeError -- and only
            # after the (large) language models had been loaded.
            raise ValueError("Unsupported dsl_name: " + repr(dsl_name))

        self.dataset_name = dataset_name
        self.dsl_name = dsl_name
        self.engine = engine
        self.temperature = temperature
        self.freq_penalty = freq_penalty
        self.max_tokens = max_tokens
        self.llm_cache_dir = llm_cache_dir

        # Score = alpha * example similarity + beta * format coverage
        #         + gamma * NER coverage (see _score_candidates).
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        self.llm = load_llm(self.engine)
        self.nlp = spacy.load("en_core_web_lg")
        self.word2vec_model = gensim.models.KeyedVectors.load_word2vec_format(
            '../GoogleNews-vectors-negative300.bin.gz', binary=True)
        self.lemmatizer = WordNetLemmatizer()

        with open("data/" + dsl_name + "/" + dataset_name.split("_")[0] + ".json", 'r') as f:
            self.dsl = AutoDsl(json.load(f))

        self.NER = NER(self.dsl_name, self.dsl.paramter_list, self.llm, self.temperature, self.freq_penalty, self.max_tokens, self.llm_cache_dir)
        # Search state; reset by every call to compile().
        self.sentence_action_ner_results_list = []
        self.results = []
        self.reg_store = []  # per sentence: (output reagent or "", index of consuming sentence or None)
        self.opt_value = []
        with open("data/output_check_prompt.txt", "r") as f:
            self.output_check_prompt = f.read()

        with open("data/input_check_prompt.txt", "r") as f:
            self.input_check_prompt = f.read()

    def compile(self, source: str) -> tuple:
        """Compile ``source`` text into AutoDSL JSON.

        Returns:
            On success a 6-tuple of strings:
            the compiled sentences (one JSON object per line); a JSON list of
            the reagent values of each compiled sentence; the JSON-encoded
            ``reg_store``; a JSON list of running averages of the per-sentence
            score components; a JSON list of ``[sentence text, compiled JSON or ""]``
            for every segmented sentence; and a JSON object mapping compiled
            index -> segmented-sentence index (both as strings).
            If the search fails, six ``None`` values (same arity as success).
        """
        passed_sentences = []  # (text, position) of sentences that could not be compiled
        sentences = self.__text_segmentation(source)
        sentence_action_ner_results_list = []
        for i, sentence in enumerate(sentences):
            logger.info(i)
            # Reset per sentence so an unsupported DSL can never reuse the
            # previous sentence's action / NER results (or hit a NameError).
            action, ner_results = None, None
            if self.dsl_name == "autodsl":
                action, ner_results = self.__autodsl_initwork(sentence)

            if action is None or ner_results is None:
                passed_sentences.append((sentence.text, i))
                continue

            sentence_action_ner_results_list.append((sentence, action, ner_results))

        n = len(sentence_action_ner_results_list)
        self.sentence_action_ner_results_list = sentence_action_ner_results_list
        self.results = [None] * n
        self.reg_store = [("", None)] * n
        self.opt_value = [(0, 0, 0)] * n

        if not self.optimize_with_constraint(0):
            return None, None, None, None, None, None

        logger.info(str(len(self.results)))
        logger.info(str(self.reg_store))
        avg_opt_value = [self.average_of_first_i(i) for i in range(n)]
        input_reg_store = [[p.value for p in r.parameters if p.property == "reagent"] for r in self.results]
        json_like_results = [self.__convert_json_format(r) for r in self.results]

        # Re-interleave skipped sentences with compiled ones in original order.
        skipped = {pos: text for text, pos in passed_sentences}
        compare_sent, mapping = [], {}
        index_trans = 0
        for i in range(len(sentences)):
            if i in skipped:
                compare_sent.append((skipped[i], ""))
            else:
                compare_sent.append((sentence_action_ner_results_list[index_trans][0].text,
                                     json.dumps(json_like_results[index_trans])))
                mapping[str(index_trans)] = str(i)
                index_trans += 1

        return (
            "\n".join(json.dumps(r) for r in json_like_results),
            json.dumps(input_reg_store),
            json.dumps(self.reg_store),
            json.dumps(avg_opt_value),
            json.dumps(compare_sent),
            json.dumps(mapping),
        )

    def __convert_json_format(self, a: "AutoDslSentence"):
        """Flatten a sentence to ``{"action", "output", <property>: [values, ...]}``."""
        d = {"action": a.action, "output": a.result}
        for p in a.parameters:
            d.setdefault(p.property, []).append(p.value)
        return d

    def average_of_first_i(self, i):
        """Return the component-wise mean of ``self.opt_value[0]`` .. ``self.opt_value[i]``."""
        count = i + 1
        totals = [0, 0, 0]
        for index in range(count):
            for k in range(3):
                totals[k] += self.opt_value[index][k]
        return tuple(total / count for total in totals)

    def __text_segmentation(self, x: str) -> "List[Sentence]":
        """Split ``x`` with NLTK, wrap each piece as a Sentence, then apply ``split_sentence``."""
        stripped = [s.strip() for s in sent_tokenize(x)]
        parsed = [Sentence.create(s, self.nlp(s), self.nlp) for s in stripped]
        return [seg for sent in parsed if sent is not None for seg in sent.split_sentence()]

    def __autodsl_initwork(self, x: "Sentence"):
        """Resolve the DSL action and NER parameters of one sentence.

        Returns ``(action, ner_results)``. ``ner_results`` is None when NER finds
        nothing; both are None if any step raises (the sentence is then skipped).
        """
        try:
            logger.info("sentence: " + x.text)

            action = similar_action(x.action, self.dsl.action_list, self.word2vec_model, self.lemmatizer)
            logger.info(" action: " + action)

            logger.info(" objects: " + str(x.objects))
            ner_results = self.NER.recognition(x)
            logger.info(" NER: " + str(ner_results))
            if len(ner_results) == 0:
                return action, None

            return action, ner_results
        except Exception:
            # Best effort: skip the sentence, but keep the traceback available.
            logger.info("wrong compiled sentence!")
            logger.debug("sentence compilation failed", exc_info=True)
            return None, None

    def __modify(self, action, text, ner_results: "List[Parameter]") -> "AutoDslSentence":
        """Pick the best format for ``action`` using all NER results (no backtracking).

        Not called anywhere in this class. Returns None when no format scores
        above 0.
        """
        formats = self.dsl.get_format(action=action.upper())
        max_score = 0
        result = {"action": "", "parms": [], "output": ""}
        for fmt in formats:
            result_now = {"action": "", "parms": [], "output": ""}
            a = 0
            for example in fmt["example"]:
                a = max(a, self.__sentense_similarity(example, text))

            revert_format = [self.dsl.label_mapping[label] for label in fmt["pattern"] if label != "output"]
            need_to_fill = {prop: revert_format.count(prop) for prop in self.dsl.paramter_list if prop != "output"}
            params = []
            count = 0
            for param in ner_results:
                if param.property == "output":
                    result_now["output"] = param.value
                    count += 1
                elif need_to_fill[param.property] > 0:
                    need_to_fill[param.property] -= 1
                    params.append(param)
                    count += 1

            # Unfilled slots become empty placeholders.
            for prop, missing in need_to_fill.items():
                for _ in range(missing):
                    params.append(Parameter(property=prop, value=""))

            # Guard empty denominators instead of raising ZeroDivisionError.
            b = count / len(revert_format) if revert_format else 0.0
            c = count / len(ner_results) if ner_results else 0.0
            result_now["action"] = action
            result_now["parms"] = params
            score = self.alpha * a + self.beta * b + self.gamma * c
            if score > max_score:
                max_score = score
                result = result_now

        if result["action"] == "":
            return None
        return AutoDslSentence(result["action"], result["parms"], result["output"])

    def _score_candidates(self, action, text, ner_results):
        """Score every (format, NER-subset) assignment for one sentence.

        Returns ``(result, score, a, b_scaled, c)`` tuples sorted by ``score``
        descending (stable). ``result`` is ``{"parms": [...], "output": str}``;
        ``a`` is the best similarity of ``text`` to the format's examples, ``b``
        the fraction of format slots filled and ``c`` the fraction of NER results
        used (both reduced, clamped at 0, for values beyond the format's slots).
        ``b_scaled`` is ``b * n / (n + 1)`` with ``n = len(ner_results)``.
        """
        formats = self.dsl.get_format(action=action.upper())
        subsets = self.__get_all_subsets(list(range(len(ner_results))))
        n = len(ner_results)
        c_step = 1 / n if n else 0.0
        score_list = []
        for fmt in formats:
            a = 0
            for example in fmt["example"]:
                a = max(a, self.__sentense_similarity(example, text))

            revert_format = [self.dsl.label_mapping[label] for label in fmt["pattern"] if label != "output"]
            # A format with only an output slot has no denominator; count 0.
            b_step = 1 / len(revert_format) if revert_format else 0.0
            slot_counts = {prop: revert_format.count(prop) for prop in self.dsl.paramter_list if prop != "output"}

            for subset in subsets:
                result = {"parms": [], "output": ""}
                need_to_fill = dict(slot_counts)
                b, c = 0, 0
                has_output = False
                for param in (ner_results[k] for k in subset):
                    if param.property == "output":
                        # Only the first output of the subset is used.
                        if not has_output:
                            result["output"] = param.value
                            b = b + b_step
                            c = c + c_step
                            has_output = True
                    elif need_to_fill[param.property] > 0:
                        need_to_fill[param.property] -= 1
                        result["parms"].append(param)
                        b = b + b_step
                        c = c + c_step
                    else:
                        # More values than slots: keep it, but penalise coverage.
                        need_to_fill[param.property] -= 1
                        result["parms"].append(param)
                        b = max(b - b_step, 0)
                        c = max(c - c_step, 0)

                # Unfilled slots become empty placeholders.
                for prop, missing in need_to_fill.items():
                    for _ in range(missing):
                        result["parms"].append(Parameter(property=prop, value=""))

                score_list.append((result, self.alpha * a + self.beta * b + self.gamma * c,
                                   a, b * n / (n + 1), c))

        return sorted(score_list, key=lambda item: item[1], reverse=True)

    def optimize_with_constraint(self, i):
        """Depth-first search placing a compiled sentence at positions ``i`` .. end.

        Candidates for sentence ``i`` are tried in descending score order; each
        must pass ``check_before`` against sentence ``i - 1`` before recursing to
        ``i + 1``. Fills ``self.results``, ``self.opt_value`` and
        ``self.reg_store`` in place. Returns True once every sentence is placed,
        False when this branch is exhausted or randomly pruned.
        """
        if i == len(self.sentence_action_ner_results_list):
            for result in self.results:
                logger.info(str(result))
            return True
        sentence, action, ner_results = self.sentence_action_ner_results_list[i]
        score_list = self._score_candidates(action, sentence.text, ner_results)

        logger.info(str(i))
        logger.info(str(ner_results))
        for result, score, _, _, _ in score_list[:5]:
            logger.info(str(result["parms"]) + " " + str(result["output"]) + " " + str(score))

        bef_sent = self.results[i - 1] if i > 0 else None
        # check_before() and the recursion mutate reg_store. Restore this
        # snapshot before every candidate so a rejected branch cannot leak stale
        # outputs or consumption marks into the next one.
        reg_store_snapshot = list(self.reg_store)

        for result, _, a, b, c in score_list:
            self.reg_store[:] = reg_store_snapshot
            logger.info("now " + str(i) + " " + str(result))
            now_sent = AutoDslSentence(action, result["parms"], result["output"])
            if result["output"] != "":
                self.reg_store[i] = (result["output"], None)

            if i > 0:
                check_flag, bef_sent_mod, now_sent_mod = self.check_before(i, bef_sent, now_sent)
            else:
                check_flag, bef_sent_mod, now_sent_mod = True, None, now_sent

            if check_flag:
                if i > 0:
                    self.results[i - 1] = bef_sent_mod
                self.results[i] = now_sent_mod
                self.opt_value[i] = (a, b, c)
                if self.optimize_with_constraint(i + 1):
                    return True
            elif random.random() >= 0.6:
                # Stochastic pruning: after a failed consistency check, abandon
                # this branch with probability 0.4; otherwise try the next one.
                return False

        return False

    def __get_all_subsets(self, s):
        """Return every subset of ``s`` as a list, ordered by size, then by position."""
        return [list(combo) for size in range(len(s) + 1) for combo in combinations(s, size)]

    def check_before(self, index: int, berfore_sent: "AutoDslSentence", sent: "AutoDslSentence"):
        """Make sentence ``index`` consistent with its predecessor.

        Step 1: if the previous sentence has no output but ``sent`` names
        reagents, one of them (LLM-chosen when ambiguous) becomes that output.
        Step 2: empty reagent slots of ``sent`` are filled (via the LLM) from
        earlier outputs in ``self.reg_store`` not yet consumed before ``index``.
        Consumption is recorded in ``self.reg_store``.

        Returns:
            ``(ok, before_sent_mod, sent_mod)``; ``(False, None, None)`` when the
            LLM gives no usable output in step 1.
        """
        if index == 0:
            # Nothing precedes the first sentence; keep the 3-tuple shape.
            return True, berfore_sent, sent

        ok, before_sent_mod = self._repair_previous_output(index, berfore_sent, sent)
        if not ok:
            return False, None, None
        return True, before_sent_mod, self._fill_missing_inputs(index, sent)

    def _repair_previous_output(self, index, before_sent, sent):
        """Step 1 of check_before: give an output-less predecessor an output.

        Returns ``(ok, before_sent_mod)``.
        """
        explicit_inputs = sorted({p.value for p in sent.parameters if p.property == "reagent" and p.value != ""})
        if before_sent.result != "" or not explicit_inputs:
            return True, before_sent

        if len(explicit_inputs) == 1:
            before_sent_mod = AutoDslSentence(before_sent.action, before_sent.parameters, explicit_inputs[0])
            self.reg_store[index - 1] = (explicit_inputs[0], index)
            return True, before_sent_mod

        query = self.output_check_prompt.replace("[Instruction]", str(before_sent)).replace(
            "[Input]", ", ".join(sorted(["\"" + a + "\"" for a in explicit_inputs])))
        results = self.llm.sample_completions(prompt=query, temperature=self.temperature, freq_penalty=self.freq_penalty, max_tokens=self.max_tokens, llm_cache_dir=self.llm_cache_dir, num_completions=4)
        logger.info("------output check")
        logger.info(query)
        for completion in results:
            response = completion.response_text.strip()
            logger.info(response)
            # The first quoted name in the answer must be one of the candidates.
            matches = re.findall(r'"([^"]*)"', response)
            if matches and matches[0] in explicit_inputs:
                logger.info(matches[0])
                before_sent_mod = AutoDslSentence(before_sent.action, before_sent.parameters, matches[0])
                self.reg_store[index - 1] = (matches[0], index)
                return True, before_sent_mod
        return False, None

    def _fill_missing_inputs(self, index, sent):
        """Step 2 of check_before: fill empty reagent slots from available earlier outputs.

        Returns ``sent`` unchanged, or a new sentence with the LLM-chosen outputs
        appended as reagents (and marked as consumed by ``index``).
        """
        need = sum(1 for p in sent.parameters if p.property == "reagent" and p.value == "")
        # Consumption by this or a later index counts as still available.
        available = [name for name, consumer in self.reg_store[:index]
                     if name != "" and (consumer is None or consumer >= index)]
        if need == 0 or not available:
            return sent

        query = self.input_check_prompt.replace("[Instruction]", str(sent)).replace(
            "[Input]", ", ".join(sorted(["\"" + a + "\"" for a in available])))
        results = self.llm.sample_completions(prompt=query, temperature=self.temperature, freq_penalty=self.freq_penalty, max_tokens=self.max_tokens, llm_cache_dir=self.llm_cache_dir, num_completions=4)
        logger.info("------input check")
        logger.info(query)
        for completion in results:
            response = completion.response_text.strip()
            logger.info(response)
            # Keep quoted names that are available, each at most as often as offered.
            remaining = available.copy()
            fill_reg = []
            for match in re.findall(r'"([^"]*)"', response):
                if match in remaining:
                    fill_reg.append(match)
                    remaining.remove(match)
            if not fill_reg:
                continue

            fill_reg = fill_reg[:need]
            new_param_list = [p for p in sent.parameters if p.property != "reagent" or p.value != ""]
            for _ in range(need - len(fill_reg)):
                new_param_list.append(Parameter(property="reagent", value=""))
            used = set()
            for reg in fill_reg:
                new_param_list.append(Parameter(property="reagent", value=reg))
                # Consume exactly one earlier output per filled reagent; never the
                # current sentence's own output.
                for j, (name, consumer) in enumerate(self.reg_store[:index]):
                    if name == reg and (consumer is None or consumer >= index) and j not in used:
                        used.add(j)
                        self.reg_store[j] = (reg, index)
                        break

            return AutoDslSentence(sent.action, new_param_list, sent.result)
        return sent

    def __sentense_similarity(self, a, b) -> float:
        """Return spaCy's document similarity between texts ``a`` and ``b``."""
        doc1 = self.nlp(a)
        doc2 = self.nlp(b)
        return doc1.similarity(doc2)
