from misc import N_TOKENS
from evo_optimization import EVO
import experiments.evaluation.instruction_induction.evo_query as evo_query
from args import parse_args
from misc import set_all_seed, TASKS, tkwargs, N_INIT, N_QUERIES, Logger
import datetime
import time
from misc import get_test_conf, get_conf
import re
from automatic_prompt_engineer import evaluate, config, template, data
from transformers import AutoModelForCausalLM, AutoTokenizer
from experiments.evaluation.instruction_induction.exec_accuracy import exec_accuracy_evaluator, exec_evaluator
from experiments.data.instruction_induction.load_data import load_data, load_query_data, save_query_data, load_init_space
from automatic_prompt_engineer import ape, data
import random
import torch
import numpy as np
import copy

import os
# Directory the script was launched from.
cwd = os.getcwd()
# Append the launch directory to PATH; presumably so helper executables kept
# there can be found by subprocesses — TODO confirm which tool relies on this.
os.environ['PATH'] = os.environ['PATH'] + ':' + cwd


# Turn off HuggingFace tokenizers parallelism (presumably to avoid its fork warning).
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class LMForwardAPI:
    """Score candidate instructions on the dev split and track the best ones.

    Scoring is delegated to ``evaluate.evaluate_prompts`` (only the
    ``'chatgpt'`` API model is supported); results are memoized per
    instruction text in ``self.prompts_set``.
    """

    def __init__(self, model_name=None, eval_data=None, init_prompt=None, init_qa=None, conf=None, base_conf=None,
                 prompt_gen_data=None, intrinsic_dim=None, n_prompt_tokens=None, few_shot_data=None,
                 HF_cache_dir=None, args=None):
        """Build the scorer.

        Args:
            model_name, init_prompt, init_qa, intrinsic_dim, n_prompt_tokens,
            HF_cache_dir: accepted for interface compatibility; not used here.
            eval_data: dev data used to score instructions.
            conf, base_conf: merged via ``config.update_config`` into ``self.conf``.
            prompt_gen_data: demonstrations used as few-shot data when
                ``few_shot_data`` is not given.
            few_shot_data: explicit few-shot demonstrations (optional).
            args: parsed CLI namespace; ``args.api_model`` selects the scorer.
        """
        # eval preparation
        self.conf = config.update_config(conf, base_conf)
        self.eval_data = eval_data
        # NOTE(review): this dev-time template uses "\n Output:" while run() uses
        # "\n\nOutput:" for the test-time template — confirm the mismatch is intended.
        self.eval_template = template.EvalTemplate(
            "Instruction: [PROMPT]\n\nInput: [INPUT]\n Output: [OUTPUT]")
        self.demos_template = template.DemosTemplate(
            "Input: [INPUT]\nOutput: [OUTPUT]")
        # Never incremented inside this class; run() reports it as `improved_count`.
        self.count = 0

        # Local models get a wrapped evaluator; anything else is kept as a name.
        if args.api_model in ['llama', 'flan-t5']:
            self.api_model = exec_evaluator(args.api_model, self.conf)
        else:
            self.api_model = args.api_model

        # Always define the attribute: fall back to the prompt-generation demos
        # only when no explicit few-shot data is supplied.
        self.few_shot_data = prompt_gen_data if few_shot_data is None else few_shot_data

        # best_train_perf / best_prompt are not read within this class.
        self.best_train_perf = 0.0
        self.best_dev_perf = 0.0
        self.best_prompt = None
        self.num_call = 0
        # Every instruction that reached the current best dev score, oldest first.
        self.best_instruction = []
        # Cache: instruction text -> (dev_perf, instruction_score).
        self.prompts_set = dict()

    def eval_instruct(self, instruction):
        """Score an instruction on the dev set (memoized by ``instruction[0]``).

        Args:
            instruction: list whose first element is the instruction text; the
                whole list is forwarded to ``evaluate.evaluate_prompts``.

        Returns:
            ``(dev_perf, instruction_score)``: the score taken from the
            evaluator's sorted result and the evaluator's second return value.

        Raises:
            NotImplementedError: if the scorer is not ``'chatgpt'`` and the
                instruction is not already cached.
        """
        assert isinstance(instruction, list)
        print('Instruction: {}'.format(instruction))
        key = instruction[0]
        if key in self.prompts_set:
            dev_perf, instruction_score = self.prompts_set[key]
        else:
            if self.api_model in ['chatgpt']:
                result, instruction_score = evaluate.evaluate_prompts(
                    instruction, self.eval_template, self.eval_data, self.demos_template, self.few_shot_data,
                    self.conf['evaluation']['method'], self.conf['evaluation'])
                # sorted() -> (prompts, scores); presumably scores[0] is the top score.
                dev_perf = result.sorted()[1][0]
                self.prompts_set[key] = (dev_perf, instruction_score)
            else:
                raise NotImplementedError

        if dev_perf > self.best_dev_perf:
            self.best_dev_perf = dev_perf
            self.best_instruction = [instruction]
        elif dev_perf == self.best_dev_perf:
            # Ties (including re-evaluating a cached prompt) are appended, so the
            # most recent instruction at the best score is what gets returned.
            self.best_instruction.append(instruction)

        self.num_call += 1
        print('STEPS:[{}]. Dev perf: {}. Best dev perf: {}'.format(
            self.num_call,
            round(float(dev_perf), 4),
            round(float(self.best_dev_perf), 4)))
        print('********* Done *********')
        return dev_perf, instruction_score

    def return_best_prompt(self):
        """Return the most recently recorded instruction with the best dev score.

        Raises IndexError if no instruction has reached ``best_dev_perf`` yet.
        """
        return self.best_instruction[-1]

    def return_prompts_set(self):
        """Return the cache mapping instruction text to (dev_perf, instruction_score)."""
        return self.prompts_set


def run_test(test_conf, best_prompt, eval_template, test_data, prompt_gen_data, demos_template, base_conf):
    """Evaluate ``best_prompt`` on the test split and return its score.

    Every argument is forwarded to ``ape.evaluate_prompts`` (``test_data`` as
    the evaluation set, ``prompt_gen_data`` as few-shot demos). The value
    returned is ``sorted()[1][0]`` of the first element of the evaluator's
    result — presumably the highest score.
    """
    evaluation = ape.evaluate_prompts(
        prompts=best_prompt,
        eval_template=eval_template,
        eval_data=test_data,
        few_shot_data=prompt_gen_data,
        demos_template=demos_template,
        conf=test_conf,
        base_conf=base_conf,
    )
    first_result = evaluation[0]
    ranked = first_result.sorted()
    scores = ranked[1]
    return scores[0]

def run(args):
    """Run evolutionary (GA or DE) instruction search for one task, then test the best instruction.

    Steps:
    1. Load and split the task data.
    2. Build the dev-set scorer.
    3. Score ``N_INIT`` instructions sampled from the pre-generated initial space.
    4. Evolve the population until ``model.stop()`` (configured with
       ``maxiter = N_INIT + N_QUERIES``).
    5. Evaluate the best instruction on the test split.

    Args:
        args: parsed CLI namespace (task, seed, query_dir, HF_cache_dir,
            intrinsic_dim, n_prompt_tokens, model_name, api_model).

    Returns:
        ``(best_score, improved_count, best_prompt)``: the test score of the
        best instruction, ``LMForwardAPI.count`` and the instruction itself.

    Raises:
        ValueError: if ``args.task`` is not in ``TASKS``.
        RuntimeError: if the task's initial instruction space cannot be loaded.
    """
    task, HF_cache_dir = args.task, args.HF_cache_dir
    intrinsic_dim, n_prompt_tokens = args.intrinsic_dim, args.n_prompt_tokens
    query_dir = args.query_dir
    # GA only: test the current best prompt whenever model.k is a multiple of this.
    record_every = 20

    if task not in TASKS:
        raise ValueError('Task not found!')

    induce_data, test_data = load_data('induce', task), load_data('eval', task)

    # Split the induce data into prompt_gen_data (half of it, capped at 100
    # examples) and eval_data (the dev set used for scoring).
    induce_data_size = len(induce_data[0])
    prompt_gen_size = min(int(induce_data_size * 0.5), 100)
    set_all_seed(args.seed)
    prompt_gen_data, eval_data = data.create_split(
        induce_data, prompt_gen_size)

    # Data is in the form input: single item, output: list of items.
    # For prompt_gen_data, sample a single item from each output list.
    prompt_gen_data = prompt_gen_data[0], [random.sample(output, 1)[0]
                                           for output in prompt_gen_data[1]]

    demos_template = "Input: [INPUT]\nOutput: [OUTPUT]"
    # Evaluation template used for test-time scoring.
    eval_template = "Instruction: [PROMPT]\n\nInput: [INPUT]\n\nOutput: [OUTPUT]"
    init_prompt = ['\n']
    prompt_gen_template = "[full_DEMO]\n\nThe instruction was to?"

    base_conf = '../experiments/configs/instruction_induction.yaml'
    conf = get_conf(task, eval_data)
    test_conf = get_test_conf(task, test_data)

    # Build the demonstration-based initial query.
    set_all_seed(args.seed)
    subsampled_data = data.subsample_data(
        prompt_gen_data, conf['generation']['num_demos'])
    prompt_gen_template = template.InitQATemplate(prompt_gen_template)
    d_template = template.DemosTemplate(demos_template)
    demos = d_template.fill(subsampled_data)
    init_qa = [prompt_gen_template.fill(demos)]

    model_forward_api = LMForwardAPI(model_name=args.model_name, eval_data=eval_data, init_prompt=init_prompt,
                                     init_qa=init_qa, conf=conf, base_conf=base_conf, prompt_gen_data=prompt_gen_data,
                                     intrinsic_dim=intrinsic_dim, n_prompt_tokens=n_prompt_tokens, HF_cache_dir=HF_cache_dir, args=args)

    try:
        instruct_emb_pairs = load_init_space(task, query_dir)
    except Exception as err:
        # Fail loudly instead of dropping into a debugger and then crashing on
        # an undefined name.
        raise RuntimeError(
            f'{task} should be generated first! Could not load the initial '
            f'instruction space from {query_dir!r}.') from err

    # Score N_INIT instructions sampled from the initial space. Keys are used
    # directly: converting them through torch.tensor would cast floats to
    # float32, and the result could fail to match float64 tuple keys.
    set_all_seed(args.seed)
    sampled_keys = random.sample(list(instruct_emb_pairs), N_INIT)
    population = []
    for key in sampled_keys:
        instruct = instruct_emb_pairs[key]
        dev_score, _ = model_forward_api.eval_instruct(instruct)
        population.append((instruct, dev_score))

    evo_opts = {
        'N': N_INIT,
        'maxiter': N_INIT + N_QUERIES,
        'algo': "GA"
    }

    logger = Logger('logs',
                    "-".join(str(datetime.datetime.now()).split(' ')+[task]),
                    ('seed', args.seed),
                    ('dataset', task),
                    ('query_dir', query_dir),
                    ('intrinsic_dim', intrinsic_dim),
                    ('n_token', n_prompt_tokens),
                    ('inititer', N_INIT),
                    ('maxiter', evo_opts['maxiter']),
                    ('algo', "EVO" + evo_opts['algo'])
                    )

    model = EVO(evo_opts)
    # The initial population already consumed N_INIT evaluations.
    model.k += N_INIT
    model.api = model_forward_api
    N = model.N

    if model.algo == "GA":
        # Loop-invariant: build the evolution template once.
        evolution_template = template.GAEvolutionTemplate(
            "Please follow the instruction step-by-step to generate a better prompt.\n1. [CROSSOVER_PROMPTS]\n2. [MUTATION_PROMPT]")
        while not model.stop():
            new_prompts = []
            for _ in range(N):
                if model.stop():
                    break
                # Periodically record test performance of the current best prompt.
                if model.k % record_every == 0:
                    best_prompt_now = model_forward_api.return_best_prompt()
                    best_score_now = run_test(test_conf,
                                              best_prompt_now,
                                              eval_template,
                                              test_data,
                                              prompt_gen_data,
                                              demos_template,
                                              base_conf)
                    logger.record(model.k, best_prompt_now, best_score_now, model_forward_api.best_dev_perf)

                prompt1, prompt2 = model.roulette_wheel_selection(population)
                next_prompt = evo_query.evolve_GA(
                    prompt1, prompt2, evolution_template, model.api.conf['generation'])
                next_score = model_forward_api.eval_instruct(next_prompt)[0]
                logger.log_instruct(next_prompt[0], next_score)
                model.k += 1
                new_prompts.append((next_prompt, next_score))
            # Survival: keep the N best of parents + offspring.
            population = population + new_prompts
            population.sort(key=lambda x: x[1], reverse=True)
            population = population[:N]

    else:
        # Loop-invariant: build the evolution template once.
        evolution_template = template.DEEvolutionTemplate(
            "Please follow the instruction step-by-step to generate a better prompt.\n1. [CROSSOVER_PROMPTS]\n2. [MUTATION_PROMPT]\n3. [COMBINE_PROMPT]\n4. [CROSSOVER_FINAL]")
        while not model.stop():
            new_prompts = []
            for i in range(N):
                if model.stop():
                    break
                prompt1, prompt2 = model.random_selection(
                    population)
                best_index = model.best_prompt(population)
                prompt3 = population[best_index][0]
                prompt4, score = population[i]  # basic prompt
                next_prompt = evo_query.evolve_DE(
                    prompt1, prompt2, prompt3, prompt4, evolution_template, model.api.conf['generation'])
                next_score = model_forward_api.eval_instruct(next_prompt)[0]
                logger.log_instruct(next_prompt[0], next_score)
                model.k += 1
                # One-to-one replacement: the child replaces its base prompt only if strictly better.
                if next_score > score:
                    new_prompts.append((next_prompt, next_score))
                else:
                    new_prompts.append((prompt4, score))
            population = new_prompts

    # Test
    print('Evaluate on test data...')
    best_prompt = model_forward_api.return_best_prompt()
    improved_count = model_forward_api.count
    print("Best instruction is:")
    print(best_prompt)

    # Evaluate on test data
    print('Evaluating on test data...')

    best_score = run_test(test_conf,
                          best_prompt,
                          eval_template,
                          test_data,
                          prompt_gen_data,
                          demos_template,
                          base_conf)

    logger.save(best_prompt, best_score, model_forward_api.best_dev_perf)
    return best_score, improved_count, best_prompt


if __name__ == '__main__':
    # Script entry point: parse CLI args, seed every RNG, run the search.
    args = parse_args()
    # Evaluation budget = initial population + evolutionary queries.
    budget = N_INIT + N_QUERIES
    print(f"Using a total of {budget} function evaluations")
    set_all_seed(args.seed)
    test_score, improved_count, prompts = run(args)
    print("Finished!!!")
    print(f'Test score on ChatGPT: {test_score}')