import argparse
import time
import logging
import itertools
import os
import pickle
import sys
import copy
import numpy as np
from argparse import Namespace

from params import *
from data import Data

#from nas_bench_201.cell import Cell
from nas_bench.cell import Cell

# Candidate operations for a single edge of a nasbench_201 cell; exhaustive_201
# enumerates every assignment of these ops to the cell's 6 edges.
OPS = ['avg_pool_3x3', 'nor_conv_1x1', 'nor_conv_3x3', 'none', 'skip_connect']


def run_local_search(search_space, arch_dict,
                        num_init=1,
                        loss='val_loss',
                        encoding_type='adjacency',
                        cutoff=30,
                        random='standard',
                        nbhd_type_pattern=('full',),
                        stop_at_minimum=True,
                        query_full_nbhd=True,
                        total_queries=10000,
                        allow_isomorphisms=False,
                        deterministic=True,
                        n=20000,
                        verbose=0):
    """Run one local-search trajectory starting from ``arch_dict``.

    Each iteration queries the not-yet-seen neighbours of the current
    architecture and moves on, until the best architecture queried so far has
    ``dist_to_min`` below 1e-5, the query budget is spent, or (when
    ``stop_at_minimum``) no neighbour improves on the current architecture.

    Args:
        search_space: object providing ``get_hash``, ``get_nbhd``,
            ``query_arch`` and ``get_type``.
        arch_dict: already-queried starting architecture; must contain
            ``'spec'``, ``loss``, ``'test_loss'`` and ``'dist_to_min'``.
        loss: key of the architecture dicts that is minimized.
        encoding_type, cutoff, deterministic: forwarded to ``query_arch``.
        nbhd_type_pattern: neighbourhood types passed to ``get_nbhd``, cycled
            through one per iteration.
        stop_at_minimum: if True, follow the best improving neighbour and stop
            when none improves. If False, always continue from the best queried
            architecture that has not been an iteration centre yet.
        query_full_nbhd: if False, stop scanning a neighbourhood at the first
            improving neighbour.
        total_queries: query budget, including the starting architecture.
        verbose: print progress when truthy.
        num_init, random, allow_isomorphisms, n: unused here; kept so callers
            can pass a uniform set of options.

    Returns:
        ``[arches, losses, test_losses, num_queries, dist_to_min]``.
        ``arches``, ``losses`` and ``test_losses`` have one entry per iteration
        centre, with the last entry replaced by the best (lowest ``loss``)
        architecture queried. ``arches`` holds ``spec['string']`` for
        nasbench_201 and ``spec`` otherwise. ``num_queries`` is the number of
        distinct architectures queried. ``dist_to_min`` is taken from that best
        architecture.
    """
    queried = {search_space.get_hash(arch_dict['spec'])}  # every queried hash
    iterated = set()  # hashes of architectures that were an iteration centre
    data = [arch_dict]  # all queried arch_dicts, in query order
    arch_dicts = []  # arch_dicts that were an iteration centre
    query = 1
    nbhd_type_num = -1

    def by_loss(archs):
        # stable sort: ties keep their query order
        return sorted(archs, key=lambda a: a[loss])

    while True:
        # loop over iterations of local search until a stopping condition hits
        if verbose:
            print('starting iteration, query', query)
        iterated.add(search_space.get_hash(arch_dict['spec']))
        arch_dicts.append(arch_dict)

        # check if we reached the global min already
        if by_loss(data)[0]['dist_to_min'] < 0.00001:
            break

        nbhd_type_num = (nbhd_type_num + 1) % len(nbhd_type_pattern)
        nbhd = search_space.get_nbhd(arch_dict['spec'],
                                     nbhd_type=nbhd_type_pattern[nbhd_type_num])
        improvement = False
        nbhd_dicts = []
        for nbr in nbhd:
            nbr_hash = search_space.get_hash(nbr)
            if nbr_hash in queried:
                continue
            queried.add(nbr_hash)

            nbr_dict = search_space.query_arch(nbr,
                                               encoding_type=encoding_type,
                                               cutoff=cutoff,
                                               deterministic=deterministic)
            data.append(nbr_dict)
            nbhd_dicts.append(nbr_dict)
            query += 1
            if query >= total_queries:
                break

            if nbr_dict[loss] < arch_dict[loss]:
                improvement = True
                if not query_full_nbhd:
                    # first-improvement mode: stop scanning this neighbourhood
                    arch_dict = nbr_dict
                    break

        if query < total_queries and not stop_at_minimum:
            # continue from the best queried arch not yet used as a centre
            candidate = next((a for a in by_loss(data)
                              if search_space.get_hash(a['spec']) not in iterated),
                             None)
            if candidate is None:
                break
            arch_dict = candidate
        elif query < total_queries and improvement:
            arch_dict = by_loss(nbhd_dicts)[0]
        else:
            # we're stopping at a min or we ran out of queries
            break

    # at the end of each run, gather all the statistics
    arch_dicts[-1] = by_loss(data)[0]

    if search_space.get_type() == 'nasbench_201':
        arches = [a['spec']['string'] for a in arch_dicts]
    else:
        arches = [a['spec'] for a in arch_dicts]
    losses = [a[loss] for a in arch_dicts]
    test_losses = [a['test_loss'] for a in arch_dicts]

    # sanity check: every architecture should have been queried at most once
    repeats = len(data) - len({search_space.get_hash(d['spec']) for d in data})
    if repeats:
        print('there were {} repeats'.format(repeats))

    return [arches, losses, test_losses, len(data), arch_dicts[-1]['dist_to_min']]


def exhaustive_201(search_space,
                    save_dir='local',
                    out_file='local',
                    num_init=1,
                    loss='val_loss',
                    encoding_type='adjacency',
                    cutoff=30,
                    random='standard',
                    nbhd_type_pattern=('full',),
                    stop_at_minimum=True,
                    query_full_nbhd=True,
                    total_queries=10000,
                    allow_isomorphisms=False,
                    deterministic=True,
                    n=20000,
                    verbose=0):
    """Run local search from each nasbench_201 cell, up to ``n`` cells.

    Cells are enumerated as every assignment of ``OPS`` to the 6 cell edges.
    Every 1000 runs, a checkpoint ``[arch_paths, loss_paths, test_loss_paths,
    queries_to_min]`` covering exactly the runs completed so far is pickled to
    ``<save_dir>/<out_file>_<num>.pkl``.

    The remaining keyword arguments are forwarded to ``run_local_search``.
    ``loss == 'random_hash'`` is also forwarded to ``query_arch`` as
    ``random_hash=True``.

    Returns:
        ``[arch_paths, loss_paths, test_loss_paths, queries_to_min]``, one
        entry per run, as produced by ``run_local_search``.
    """
    arch_paths = []
    loss_paths = []
    test_loss_paths = []
    queries_to_min = []
    dist_to_min = []

    # loop over all arches in nasbench_201 (at most n of them)
    all_ops = itertools.islice(itertools.product(OPS, repeat=6), n)
    for num, ops in enumerate(all_ops, start=1):
        if num % 25 == 0:
            print(num)

        # nasbench 201 specific
        arch = Cell.get_string_from_ops(ops)
        arch_dict = search_space.query_arch(arch={'string': arch},
                                            random_hash=(loss == 'random_hash'),
                                            deterministic=deterministic)

        arches, losses, test_losses, qs, dist = run_local_search(
            search_space, arch_dict,
            num_init=num_init,
            loss=loss,
            encoding_type=encoding_type,
            cutoff=cutoff,
            random=random,
            nbhd_type_pattern=nbhd_type_pattern,
            stop_at_minimum=stop_at_minimum,
            query_full_nbhd=query_full_nbhd,
            total_queries=total_queries,
            allow_isomorphisms=allow_isomorphisms,
            deterministic=deterministic,
            n=n,
            verbose=verbose)

        arch_paths.append(arches)
        loss_paths.append(losses)
        test_loss_paths.append(test_losses)
        queries_to_min.append(qs)
        dist_to_min.append(dist)

        if verbose:
            print('finished a run of local search')
            print('test losses', test_losses)
            # same "reached the optimum" threshold that run_local_search uses
            if dist_to_min[-1] < 0.00001:
                print('found the optimal architecture in', queries_to_min[-1], 'queries')
            else:
                print('converged', dist_to_min[-1], 'away from optimal in', queries_to_min[-1], 'queries')

        # checkpoint after the run so <out_file>_<num>.pkl holds exactly num runs
        if num % 1000 == 0:
            filename = os.path.join(save_dir, '{}_{}.pkl'.format(out_file, num))
            print('Saving to file {}'.format(filename))
            with open(filename, 'wb') as f:
                pickle.dump([arch_paths, loss_paths, test_loss_paths, queries_to_min], f)

    return [arch_paths, loss_paths, test_loss_paths, queries_to_min]

def spec_hash(spec):
    """Build a string key for a spec: its op names followed by its flattened matrix entries."""
    pieces = list(spec['ops'])
    pieces.extend(str(entry) for entry in spec['matrix'].flatten())
    return ''.join(pieces)

def sample_101(search_space,
                save_dir='local',
                out_file='local',
                num_init=1,
                loss='val_loss',
                encoding_type='adjacency',
                cutoff=40,
                random='standard',
                nbhd_type_pattern=('full',),
                stop_at_minimum=True,
                query_full_nbhd=True,
                total_queries=10000,
                allow_isomorphisms=False,
                deterministic=True,
                n=10000,
                verbose=0):
    """Run ``n`` local searches on nasbench-101, each from a queried start arch.

    Every 5000 runs, a checkpoint ``[loss_paths, test_loss_paths,
    queries_to_min]`` covering exactly the runs completed so far is pickled to
    ``<save_dir>/<out_file>_<num>.pkl``. The remaining keyword arguments are
    forwarded to ``run_local_search``.

    Returns:
        ``[loss_paths, test_loss_paths, queries_to_min]``, one entry per run.
        Architecture paths are not kept.
    """
    loss_paths = []
    test_loss_paths = []
    queries_to_min = []
    dist_to_min = []

    for num in range(1, n + 1):
        if num % 200 == 0:
            print(num)

        # presumably query_arch() without an arch samples a random one — TODO confirm
        arch_dict = search_space.query_arch(deterministic=deterministic)

        _, losses, test_losses, qs, dist = run_local_search(
            search_space, arch_dict,
            num_init=num_init,
            loss=loss,
            encoding_type=encoding_type,
            cutoff=cutoff,
            random=random,
            nbhd_type_pattern=nbhd_type_pattern,
            stop_at_minimum=stop_at_minimum,
            query_full_nbhd=query_full_nbhd,
            total_queries=total_queries,
            allow_isomorphisms=allow_isomorphisms,
            deterministic=deterministic,
            n=n,
            verbose=verbose)

        loss_paths.append(losses)
        test_loss_paths.append(test_losses)
        queries_to_min.append(qs)
        dist_to_min.append(dist)

        if verbose:
            print('finished a run')
            print(test_losses)
            print(queries_to_min[-1])
            print(dist_to_min[-1])

        # checkpoint after the run so <out_file>_<num>.pkl holds exactly num runs
        if num % 5000 == 0:
            filename = os.path.join(save_dir, '{}_{}.pkl'.format(out_file, num))
            print('Saving to file {}'.format(filename))
            with open(filename, 'wb') as f:
                pickle.dump([loss_paths, test_loss_paths, queries_to_min], f)

    return [loss_paths, test_loss_paths, queries_to_min]

def run_experiments(args, save_dir):
    """Run ``args.trials`` local-search experiments and pickle each trial's results.

    nasbench-101 search spaces go through ``sample_101``; all others go through
    ``exhaustive_201``. Each trial's results are written to
    ``<save_dir>/<output_filename>_<trial>.pkl``.
    """
    out_file = args.output_filename
    ls_tree = args.ls_tree
    metann_params = meta_neuralnet_params(args.search_space)

    # set up search space
    remaining = copy.deepcopy(metann_params)
    ss = remaining.pop('search_space')
    mf = remaining.pop('mf')
    dataset = remaining.pop('dataset')
    search_space = Data(ss, mf=mf, dataset=dataset)

    # search settings depend only on the search space and flags, not on the trial
    if ls_tree:
        stop_at_minimum, query_full_nbhd, total_queries = True, True, 10000
    elif ss != 'nasbench' and dataset == 'cifar10':
        stop_at_minimum, query_full_nbhd, total_queries = True, False, 10000
    else:
        stop_at_minimum, query_full_nbhd, total_queries = False, True, 1000

    runner = sample_101 if ss == 'nasbench' else exhaustive_201

    for trial in range(args.trials):
        results = runner(search_space,
                         stop_at_minimum=stop_at_minimum,
                         query_full_nbhd=query_full_nbhd,
                         total_queries=total_queries,
                         save_dir=save_dir,
                         out_file=out_file)

        filename = os.path.join(save_dir, '{}_{}.pkl'.format(out_file, trial))
        print('\n* Saving to file {}'.format(filename))
        with open(filename, 'wb') as handle:
            pickle.dump(results, handle)

def main(args):
    """Create the save directory, set up logging, and run the experiments.

    Args:
        args: parsed command-line namespace. ``save_dir`` and ``algo_params``
            are read here; the whole namespace is logged and passed on to
            ``run_experiments``.
    """
    # make save directory (falls back to "<algo_params>/")
    save_dir = args.save_dir
    if not save_dir:
        save_dir = args.algo_params + '/'
    # makedirs creates missing parent directories and tolerates an existing one
    os.makedirs(save_dir, exist_ok=True)

    # set up logging: stdout plus <save_dir>/log.txt
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info(args)

    run_experiments(args, save_dir)
    

def str2bool(value):
    """Parse a command-line boolean value such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so flag values are parsed explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Args for experiments')
    parser.add_argument('--trials', type=int, default=1, help='Number of trials')
    parser.add_argument('--search_space', type=str, default='nasbench_201_cifar10')
    parser.add_argument('--algo_params', type=str, default='local_exps', help='which parameters to use')
    parser.add_argument('--output_filename', type=str, default='round', help='name of output files')
    parser.add_argument('--save_dir', type=str, default=None, help='name of save directory')
    # accepts "--ls_tree", "--ls_tree True" and "--ls_tree False"
    parser.add_argument('--ls_tree', type=str2bool, nargs='?', const=True, default=False,
                        help='construct local search tree')

    args = parser.parse_args()
    main(args)

