import numpy as np
import pickle
import sys
import os

from fidelity_sampler import sample_fidelity

# Choose which architecture classes to import based on the `search_space`
# environment variable, which is read once when this module is imported.
if 'search_space' not in os.environ:
    # Variable unset: same import as the explicit 'nasbench' setting below.
    from nas_bench.cell import Cell
elif os.environ['search_space'] == 'nasbench':
    from nas_bench.cell import Cell
elif os.environ['search_space'] == 'darts':
    # Only Arch is bound in this configuration; Cell is not imported.
    from darts.arch import Arch
else:
    # Any other value: the NAS-Bench-201 API and its own Cell class.
    from nas_201_api import NASBench201API as API
    from nas_bench_201.cell import Cell


class Data:

    def __init__(self, search_space, mf=False, dataset='cifar10', nasbench_folder='./', loaded_nasbench=None):
        """
        Set up the benchmark backend for the requested search space.

        search_space: 'nasbench', 'nasbench_201' or 'darts'. Any other value
            prints a message and exits the process.
        mf: if True, the nasbench space loads 'nasbench_full.tfrecord'
            instead of 'nasbench_only108.tfrecord'.
        dataset: stored on the instance as `self.dataset`.
        nasbench_folder: prefix prepended to the tfrecord file name.
        loaded_nasbench: an already-loaded nasbench object to reuse; when
            truthy, no tfrecord file is read.
        """
        # Rebind Arch at module level in the 'darts' branch below. A plain
        # function-local import would be discarded when __init__ returns,
        # leaving the other methods without an `Arch` name.
        global Arch

        self.search_space = search_space
        self.mf = mf
        self.dataset = dataset

        # NOTE: unpickling can execute arbitrary code; this file must be trusted.
        index_hash_path = os.path.expanduser('~/nas_encodings/index_hash.pkl')
        with open(index_hash_path, 'rb') as index_hash_file:
            self.index_hash = pickle.load(index_hash_file)

        if search_space == 'nasbench':
            if loaded_nasbench:
                self.nasbench = loaded_nasbench
            else:
                from nasbench import api

                if mf:
                    self.nasbench = api.NASBench(nasbench_folder + 'nasbench_full.tfrecord')
                else:
                    self.nasbench = api.NASBench(nasbench_folder + 'nasbench_only108.tfrecord')

        elif search_space == 'nasbench_201':
            self.nasbench = API(os.path.expanduser('~/nas-bench-201/NAS-Bench-201-v1_0-e61699.pth'))

        elif search_space == 'darts':
            from darts.arch import Arch
        else:
            print(search_space, 'is not a valid search space')
            sys.exit()

    def get_type(self):
        """Return the search space name stored on this instance."""
        return self.search_space

    def get_mf(self):
        """Return the multi-fidelity flag stored on this instance."""
        return self.mf

    def epoch_encoding(self, encoding, epochs, change=False):
        """
        Attach a 4-entry one-hot fidelity marker to an architecture encoding.

        The first three slots stand for 4, 12 and 36 epochs; every other
        epoch value sets the last slot. With change=True the trailing 4
        entries of `encoding` are dropped first, so an existing marker is
        replaced rather than stacked. Always returns a new list.
        Currently only set up for nasbench space.
        """
        base = encoding[:-4] if change else encoding

        known_budgets = (
            (4, (1, 0, 0, 0)),
            (12, (0, 1, 0, 0)),
            (36, (0, 0, 1, 0)),
        )
        marker = (0, 0, 0, 1)
        for budget, one_hot in known_budgets:
            if epochs == budget:
                marker = one_hot
                break

        return list(base) + list(marker)

    def convert_to_cells(self,
                         arches,
                         encoding_type='path',
                         cutoff=40,
                         train=True):
        """
        Turn each entry of `arches` into a cell spec with
        Cell.convert_to_cell and query it with query_arch.

        Returns the resulting arch dictionaries in input order.
        """
        return [
            self.query_arch(Cell.convert_to_cell(arch),
                            encoding_type=encoding_type,
                            cutoff=cutoff,
                            train=train)
            for arch in arches
        ]

    def query_arch(self,
                   arch=None,
                   train=True,
                   encoding_type='path',
                   random='standard',
                   deterministic=True,
                   epochs=0,
                   cutoff=-1,
                   random_hash=False,
                   max_edges=-1,
                   max_nodes=-1):
        """
        Build the dictionary describing one architecture.

        arch: architecture spec. When None a random one is sampled: for the
            'nasbench' / 'nasbench_201' spaces with the Cell sampler chosen
            by `random` (or the constrained sampler when max_edges or
            max_nodes is positive), otherwise with Arch.random_arch().
        train: when True, also query losses (and, for the cell-based spaces,
            'num_params', 'val_per_param' and 'dist_to_min').
        encoding_type: which encoding to store under 'encoding'.
        random: name of the random sampling method (cell-based spaces only).
        deterministic: forwarded to Cell.get_val_loss.
        epochs: fidelity, stored under 'epochs'. In the cell-based spaces it
            is also appended to the encoding via epoch_encoding.
        cutoff: passed as `freq` to the 'adj_freq' and 'path_freq' samplers.
        random_hash: for 'nasbench_201', also store Cell.get_random_hash().
        max_edges, max_nodes: constraints for the constrained sampler.

        Returns the arch dictionary; it always contains 'epochs', 'spec' and
        'encoding'.

        Raises ValueError in the cell-based spaces when `random` (with
        arch=None) or `encoding_type` is not one of the supported names.
        """
        arch_dict = {'epochs': epochs}

        if self.search_space in ['nasbench', 'nasbench_201']:
            if arch is None:
                if max_edges > 0 or max_nodes > 0:
                    arch = Cell.random_cell_constrained(self.nasbench,
                                                        max_edges=max_edges,
                                                        max_nodes=max_nodes)
                # different random methods
                elif random == 'continuous':
                    arch = Cell.random_cell_continuous(self.nasbench)
                elif random == 'uniform':
                    arch = Cell.random_cell_uniform(self.nasbench)
                elif random == 'path':
                    arch = Cell.random_cell_path(self.nasbench, self.index_hash)
                elif random in ['standard', 'adjacency', 'adj']:
                    arch = Cell.random_cell(self.nasbench)
                elif random in ['adj_freq']:
                    arch = Cell.random_cell(self.nasbench, freq=cutoff)
                elif random == 'path_cont':
                    arch = Cell.random_cell_path_cont(self.nasbench, self.index_hash)
                elif random == 'path_freq':
                    arch = Cell.random_cell_path(self.nasbench, self.index_hash, freq=cutoff)
                elif random == 'path_cont_freq':
                    # NOTE(review): freq is fixed at 40 here while the other
                    # *_freq samplers use `cutoff` -- confirm this is intended.
                    arch = Cell.random_cell_path_cont(self.nasbench, self.index_hash, freq=40)
                else:
                    # Fail here with a clear message; continuing with
                    # arch=None would only crash later in Cell(**arch).
                    raise ValueError('no arch: {!r} is not a valid random method'.format(random))
            arch_dict['spec'] = arch

            # different encoding methods
            if encoding_type in ['adj', 'adjacency']:
                encoding = Cell(**arch).encode_standard()
            elif encoding_type == 'cat_adj':
                encoding = Cell(**arch).encode_adj_cat()
            elif encoding_type == 'cont_adj':
                encoding = Cell(**arch).encode_continuous()
            elif encoding_type == 'path':
                encoding = Cell(**arch).encode_paths()
            elif encoding_type == 'cat_path':
                # pad the path indices with zeros up to length 20
                indices = Cell(**arch).get_path_indices()
                encoding = tuple([*indices, *[0]*(20-len(indices))])
            elif encoding_type == 'trunc_path':
                encoding = Cell(**arch).encode_freq_paths()
            elif encoding_type == 'trunc_cat_path':
                # NOTE(review): the truncation threshold is a literal 40, not
                # `cutoff` -- confirm this is intended.
                indices = [i for i in Cell(**arch).get_path_indices() if i < 40]
                encoding = tuple([*indices, *[0]*(20-len(indices))])
            else:
                # Fail here; `encoding` would otherwise be unbound below.
                raise ValueError('invalid encoding type: {!r}'.format(encoding_type))

            arch_dict['encoding'] = self.epoch_encoding(encoding, epochs)

            # special keys for local search and outside_ss experiments
            if self.search_space == 'nasbench_201' and random_hash:
                arch_dict['random_hash'] = Cell(**arch).get_random_hash()
            if self.search_space == 'nasbench':
                arch_dict['adjacency'] = Cell(**arch).encode_standard()
                arch_dict['path'] = Cell(**arch).encode_paths()

            if train:
                if not self.get_mf():
                    arch_dict['val_loss'] = Cell(**arch).get_val_loss(self.nasbench,
                                                                      deterministic=deterministic,
                                                                      dataset=self.dataset)
                    arch_dict['test_loss'] = Cell(**arch).get_test_loss(self.nasbench,
                                                                        dataset=self.dataset)
                else:
                    arch_dict['val_loss'] = Cell(**arch).get_val_loss(self.nasbench,
                                                                      deterministic=deterministic,
                                                                      epochs=epochs)
                    arch_dict['test_loss'] = Cell(**arch).get_test_loss(self.nasbench, epochs=epochs)

                arch_dict['num_params'] = Cell(**arch).get_num_params(self.nasbench)
                # NOTE(review): 4.8 is an unexplained offset -- confirm its origin.
                arch_dict['val_per_param'] = (arch_dict['val_loss'] - 4.8) * (arch_dict['num_params'] ** 0.5) / 100

                # The subtracted constants are presumably the best known
                # validation loss per benchmark/dataset -- TODO confirm.
                if self.search_space == 'nasbench':
                    arch_dict['dist_to_min'] = arch_dict['val_loss'] - 4.94457682
                elif self.dataset == 'cifar10':
                    arch_dict['dist_to_min'] = arch_dict['val_loss'] - 8.3933
                elif self.dataset == 'cifar100':
                    arch_dict['dist_to_min'] = arch_dict['val_loss'] - 26.5067
                else:
                    arch_dict['dist_to_min'] = arch_dict['val_loss'] - 53.2333

        else:
            if arch is None:
                arch = Arch.random_arch()

            if encoding_type == 'path':
                encoding = Arch(arch).encode_paths()
            elif encoding_type == 'path-short':
                encoding = Arch(arch).encode_freq_paths()
            else:
                # any other encoding type: the raw arch is used as its encoding
                encoding = arch
            arch_dict['spec'] = arch

            # todo add mf encoding options here
            arch_dict['encoding'] = encoding

            if train:
                # epochs == 0 means "not specified": query at 50 epochs
                if epochs == 0:
                    epochs = 50
                arch_dict['val_loss'], arch_dict['test_loss'] = Arch(arch).query(epochs=epochs)

        return arch_dict

    def mutate_arches(self, arches):
        """
        Return mutated copies of every architecture in `arches`.

        For each arch, mutate_arch is called 10 times for every mutation
        rate from 1 to 10, giving 100 mutations per arch, all collected in
        one flat list.
        """
        # method for metann_outside. currently not being used
        mutations = []
        for arch in arches:
            for _ in range(10):
                for rate in range(1, 11):
                    mutations.append(self.mutate_arch(arch, mutation_rate=rate))

        return mutations

    def mutate_arch(self, arch,
                    mutation_rate=1.0,
                    encoding_type='adjacency',
                    mutate_type='adj',
                    comparisons=2500,
                    cutoff=-1):
        """
        Return a mutated version of `arch`.

        For the 'nasbench' and 'nasbench_201' spaces the work is delegated
        to Cell.mutate; for every other space Arch.mutate is called with
        the mutation rate cast to int. `comparisons` is accepted but not
        used here.
        """
        cell_based = self.search_space in ['nasbench', 'nasbench_201']
        if not cell_based:
            return Arch(arch).mutate(int(mutation_rate))

        return Cell(**arch).mutate(self.nasbench,
                                   mutation_rate=mutation_rate,
                                   mutate_type=mutate_type,
                                   encoding_type=encoding_type,
                                   index_hash=self.index_hash,
                                   cutoff=cutoff)

    def get_nbhd(self, arch, nbhd_type='full'):
        """
        Return the neighborhood of `arch` for the given neighborhood type.

        'nasbench' additionally passes the path index hash to
        Cell.get_neighborhood; 'nasbench_201' does not. Any other search
        space goes through Arch.get_neighborhood.
        """
        space = self.search_space
        if space not in ('nasbench', 'nasbench_201'):
            return Arch(arch).get_neighborhood(nbhd_type=nbhd_type)

        cell = Cell(**arch)
        if space == 'nasbench':
            return cell.get_neighborhood(self.nasbench,
                                         nbhd_type=nbhd_type,
                                         index_hash=self.index_hash)
        return cell.get_neighborhood(self.nasbench, nbhd_type=nbhd_type)

    def get_hash(self, arch, epochs=0):
        """
        Return a hashable key identifying the architecture.

        'nasbench': the cell's path indices followed by `epochs`.
        'darts': the first element of Arch.get_path_indices(), unpacked and
        followed by `epochs`.
        Any other space: Cell.get_string(); `epochs` is not part of the key.
        """
        space = self.search_space
        if space == 'nasbench':
            indices = Cell(**arch).get_path_indices()
            return (*indices, epochs)
        if space == 'darts':
            first_indices = Arch(arch).get_path_indices()[0]
            return (*first_indices, epochs)
        return Cell(**arch).get_string()

    # todo change kwarg to deterministic
    def generate_random_dataset(self,
                                num=10,
                                train=True,
                                encoding_type='path',
                                random='standard',
                                allow_isomorphisms=False,
                                deterministic_loss=True,
                                patience_factor=5,
                                mf_type=None,
                                cutoff=-1,
                                max_edges=-1,
                                max_nodes=-1):
        """
        Create a dataset of randomly sampled architectures.

        Duplicates are detected with get_hash on (spec, epochs) and skipped
        unless allow_isomorphisms is True. At most num * patience_factor
        architectures are sampled, so the loop always terminates; the
        returned list can therefore hold fewer than `num` entries.

        When mf_type is set, the fidelity of each sample is drawn with
        sample_fidelity; otherwise epochs is 0. The remaining arguments are
        forwarded to query_arch (deterministic_loss as `deterministic`).
        """
        data = []
        seen = set()
        # Spend one try per sampled architecture. Checking the budget before
        # decrementing means the full num * patience_factor tries are used.
        tries_left = num * patience_factor
        while len(data) < num and tries_left > 0:
            tries_left -= 1

            epochs = sample_fidelity(mf_type, query_proportion=0) if mf_type else 0

            arch_dict = self.query_arch(train=train,
                                        encoding_type=encoding_type,
                                        random=random,
                                        deterministic=deterministic_loss,
                                        epochs=epochs,
                                        cutoff=cutoff,
                                        max_edges=max_edges,
                                        max_nodes=max_nodes)

            h = self.get_hash(arch_dict['spec'], epochs)

            if allow_isomorphisms or h not in seen:
                seen.add(h)
                data.append(arch_dict)

        return data


    def get_candidates(self, data,
                       num=100,
                       acq_opt_type='mutation',
                       encoding_type='path',
                       mutate_type='adjacency',
                       loss='val_loss',
                       allow_isomorphisms=False,
                       patience_factor=5,
                       deterministic_loss=True,
                       num_arches_to_mutate=1,
                       max_mutation_rate=1,
                       add_data=False,
                       cutoff=-1):
        """
        Build a list of candidate architectures (queried with train=False).

        acq_opt_type 'mutation' mutates the lowest-`loss` entries of `data`,
        'random' samples new architectures, and 'mutation_random' does both.
        Candidates whose hash (at epochs 0) matches an entry of `data` or an
        earlier candidate are skipped unless allow_isomorphisms is True.
        `deterministic_loss` and `add_data` are accepted but not used here.
        """
        # hashes of everything seen so far, seeded with the existing data
        seen = {self.get_hash(entry['spec'], 0) for entry in data}
        candidates = []

        if acq_opt_type in ['mutation', 'mutation_random']:
            # parents: the specs with the lowest loss; patience_factor widens
            # the pool so the loop is bounded without needing a while loop
            ranked = sorted(data, key=lambda entry: entry[loss])
            parents = [entry['spec']
                       for entry in ranked[:num_arches_to_mutate * patience_factor]]

            for parent in parents:
                # the size is only checked between parents
                if len(candidates) >= num:
                    break
                for _ in range(num // num_arches_to_mutate // max_mutation_rate):
                    for rate in range(1, max_mutation_rate + 1):
                        child = self.mutate_arch(parent,
                                                 mutation_rate=rate,
                                                 mutate_type=mutate_type)
                        child_dict = self.query_arch(child,
                                                     train=False,
                                                     encoding_type=encoding_type,
                                                     cutoff=cutoff)
                        child_hash = self.get_hash(child, 0)

                        if allow_isomorphisms or child_hash not in seen:
                            seen.add(child_hash)
                            candidates.append(child_dict)

        if acq_opt_type in ['random', 'mutation_random']:
            # top up with randomly sampled architectures, to at most 2 * num
            for _ in range(num * patience_factor):
                if len(candidates) >= 2 * num:
                    break

                sampled = self.query_arch(train=False,
                                          encoding_type=encoding_type,
                                          cutoff=cutoff)
                sampled_hash = self.get_hash(sampled['spec'], 0)

                if allow_isomorphisms or sampled_hash not in seen:
                    seen.add(sampled_hash)
                    candidates.append(sampled)

        return candidates

    def get_next_fidelities(self, data,
                            encoding_type='path',
                            deterministic_loss=True,
                            cutoff=-1):
        """
        For each architecture in `data`, build the arch dictionary for its
        next fidelity.

        The next fidelity of an entry is 3 * epochs, or -1 once epochs is
        108. Per architecture (hashed at epochs 0) the largest next fidelity
        is kept, and -1 wins over everything. Architectures marked -1 are
        left out; every other architecture appears once, in order of first
        occurrence, queried with train=False. `deterministic_loss` is
        accepted but not used here.
        """
        # pass 1: the fidelity each architecture should be evaluated at next
        target = {}
        for entry in data:
            key = self.get_hash(entry['spec'], 0)
            current = entry['epochs']
            upcoming = 3 * current if current != 108 else -1
            if upcoming == -1 or key not in target:
                target[key] = upcoming
            elif target[key] != -1:
                target[key] = max(target[key], upcoming)

        # pass 2: emit one query per architecture that has a next fidelity
        next_fidelities = []
        for entry in data:
            key = self.get_hash(entry['spec'], 0)
            if target.get(key, -1) == -1:
                continue
            # pop so later duplicates of this architecture are skipped
            epochs = target.pop(key)
            next_fidelities.append(self.query_arch(entry['spec'],
                                                   train=False,
                                                   encoding_type=encoding_type,
                                                   epochs=epochs,
                                                   cutoff=cutoff))

        return next_fidelities


    def remove_duplicates(self, candidates, data):
        """
        Return the candidates whose (spec, epochs) hash appears neither in
        `data` nor in an earlier candidate.

        input: two sets of architectures: candidates and data
        output: candidates with arches from data removed, order preserved
        """
        seen = {self.get_hash(entry['spec'], entry['epochs']) for entry in data}

        unduplicated = []
        for candidate in candidates:
            key = self.get_hash(candidate['spec'], candidate['epochs'])
            if key in seen:
                continue
            seen.add(key)
            unduplicated.append(candidate)
        return unduplicated

    # todo: this will not be needed once metann_runner is updated
    def encode_data(self, dicts):
        """
        Convert arch dictionaries into (spec, path encoding, loss, None)
        tuples.

        input: list of arch dictionary objects with 'spec' and
            'val_loss_avg' keys
        output: one tuple per dictionary, in input order; the encoding is
            always produced by Arch.encode_paths, whatever the search space
        """
        return [
            (dic['spec'], Arch(dic['spec']).encode_paths(), dic['val_loss_avg'], None)
            for dic in dicts
        ]

    # Method used for gp_bayesopt
    def get_arch_list(self,
                        aux_file_path,
                        distance=None,
                        iteridx=0,
                        num_top_arches=5,
                        max_edits=20,
                        num_repeats=5,
                        verbose=1):
        """
        Propose new architectures by mutating the first entries of a pickled
        architecture list.

        aux_file_path: pickle file holding the list of architectures chosen
            so far. Each entry is indexed as entry[0] (the spec passed to
            Cell) and entry[1][0] (printed as a val loss when verbose).
        num_top_arches: how many leading entries of that list are mutated.
        max_edits: edit counts 1 .. max_edits - 1 are tried per arch.
        num_repeats: mutations attempted per (arch, edit count) pair.
        verbose: when truthy, print the losses of the first 5 entries.
        `distance` and `iteridx` are accepted but not used here.

        Returns the new architectures whose path indices differ from every
        architecture in the file and from each other. If mutation yields
        nothing, random cells are sampled until the list is non-empty.
        Exits the process when the search space is 'darts'.
        """
        if self.search_space == 'darts':
            print('get_arch_list only supported for nasbench and nasbench_201')
            sys.exit()

        # load the list of architectures chosen by bayesopt so far
        # NOTE: unpickling can execute arbitrary code; the file must be trusted.
        with open(aux_file_path, 'rb') as aux_file:
            base_arch_list = pickle.load(aux_file)
        top_arches = [archtuple[0] for archtuple in base_arch_list[:num_top_arches]]
        if verbose:
            top_5_loss = [archtuple[1][0] for archtuple in base_arch_list[:5]]
            print('top 5 val losses {}'.format(top_5_loss))

        # path indices of every architecture already known
        seen = {Cell(**archtuple[0]).get_path_indices() for archtuple in base_arch_list}

        # perturb the best k architectures
        new_arch_list = []
        for arch in top_arches:
            for edits in range(1, max_edits):
                for _ in range(num_repeats):
                    perturbation = Cell(**arch).mutate(self.nasbench, edits)
                    path_indices = Cell(**perturbation).get_path_indices()
                    if path_indices not in seen:
                        seen.add(path_indices)
                        new_arch_list.append(perturbation)

        # make sure new_arch_list is not empty
        while not new_arch_list:
            for _ in range(100):
                arch = Cell.random_cell(self.nasbench)
                path_indices = Cell(**arch).get_path_indices()
                if path_indices not in seen:
                    seen.add(path_indices)
                    new_arch_list.append(arch)

        return new_arch_list

    # Method used for gp_bayesopt for nasbench
    @classmethod
    def generate_distance_matrix(cls, arches_1, arches_2, distance):
        """
        Return a len(arches_1) x len(arches_2) matrix of pairwise distances.

        `distance` names the Cell method used to compare two cells
        ('edit_distance', 'path_distance', 'cont_adj_distance',
        'cont_path_distance', 'nasbot_distance' or 'freq_distance').
        An unknown name prints a message and exits the process as soon as
        the first pair is compared, so with an empty input it goes unnoticed.
        """
        supported = ('edit_distance',
                     'path_distance',
                     'cont_adj_distance',
                     'cont_path_distance',
                     'nasbot_distance',
                     'freq_distance')

        matrix = np.zeros([len(arches_1), len(arches_2)])
        for i, arch_1 in enumerate(arches_1):
            for j, arch_2 in enumerate(arches_2):
                if distance not in supported:
                    print('{} is an invalid distance'.format(distance))
                    sys.exit()
                # the supported names are exactly the Cell method names
                metric = getattr(Cell(**arch_1), distance)
                matrix[i][j] = metric(Cell(**arch_2))
        return matrix
