import numpy as np
import copy
import itertools
import random
import sys
import os
import pickle

from nasbench import api


# Op labels used in NAS-Bench-101 cell specs: the fixed input/output
# vertices, followed by the three interior operations.
INPUT, OUTPUT = 'input', 'output'
CONV3X3, CONV1X1, MAXPOOL3X3 = (
    'conv3x3-bn-relu',
    'conv1x1-bn-relu',
    'maxpool3x3',
)
# Ops that may be assigned to an interior (non input/output) vertex.
OPS = [CONV3X3, CONV1X1, MAXPOOL3X3]

# Search-space size: vertices per cell (including input and output),
# interior vertices that carry an op, and the edge budget.
NUM_VERTICES = 7
OP_SPOTS = NUM_VERTICES - len((INPUT, OUTPUT))
MAX_EDGES = 9


class Cell:
    """A NAS-Bench-101 cell: a 0/1 adjacency matrix plus one op label per vertex.

    ``matrix[i][j]`` is truthy when there is an edge from vertex ``i`` to
    vertex ``j``.  The samplers below set ``ops[0] = INPUT`` and
    ``ops[-1] = OUTPUT``, so vertex 0 is the input and the last vertex is the
    output.  Most classmethods return plain ``{'matrix': ..., 'ops': ...}``
    dicts rather than ``Cell`` instances.
    """

    # Inclusion probability for a single path with ``i`` interior ops
    # (index = i), used when sampling or mutating sets of path indices.
    # NOTE(review): the origin of these constants is not shown here —
    # presumably empirical NAS-Bench-101 path frequencies; confirm.
    _PATH_PROBS = (.2, .127, 3.36*10**-2, 3.92*10**-3, 1.5*10**-4, 6.37*10**-7)

    def __init__(self, matrix, ops):
        # matrix: square 0/1 adjacency matrix, indexable as matrix[i][j]
        # ops: list of op labels, one per vertex
        self.matrix = matrix
        self.ops = ops

    @staticmethod
    def _total_paths():
        """Number of distinct path op-sequences with 0..OP_SPOTS interior ops."""
        return sum(len(OPS) ** i for i in range(OP_SPOTS + 1))

    def serialize(self):
        """Return the cell as a {'matrix', 'ops'} dict (references, not copies)."""
        return {
            'matrix': self.matrix,
            'ops': self.ops
        }

    def get_utilized(self):
        """Return the edges and vertices that lie on some input->output path.

        Returns:
            (utilized_edges, utilized_nodes): a duplicate-free list of
            ``(src, dst)`` tuples and an ascending list of vertex indices.
        """
        n = np.shape(self.matrix)[0]
        # sub_paths[j] holds every path from vertex 0 to vertex j, each path
        # stored as a list of (src, dst) edges.
        sub_paths = [[[(0, j)]] if self.matrix[0][j] else [] for j in range(n)]

        # Extend paths row by row.  This assumes the matrix is upper
        # triangular (as produced by np.triu in the samplers), so all paths
        # into vertex i are complete before row i is processed.
        for i in range(1, n - 1):
            for j in range(1, n):
                if self.matrix[i][j]:
                    for sub_path in sub_paths[i]:
                        sub_paths[j].append([*sub_path, (i, j)])
        paths = sub_paths[-1]

        utilized_edges = []
        for path in paths:
            for edge in path:
                if edge not in utilized_edges:
                    utilized_edges.append(edge)

        # Iterate over the actual matrix size rather than NUM_VERTICES.
        utilized_nodes = [i for i in range(n)
                          if any(i in edge for edge in utilized_edges)]
        return utilized_edges, utilized_nodes

    def num_edges_and_vertices(self):
        """Return (number of utilized edges, number of utilized vertices)."""
        edges, nodes = self.get_utilized()
        return len(edges), len(nodes)

    def is_valid_vertex(self, vertex):
        """True if `vertex` lies on some input->output path."""
        edges, nodes = self.get_utilized()
        return (vertex in nodes)

    def is_valid_edge(self, edge):
        """True if the (src, dst) tuple `edge` lies on some input->output path."""
        edges, nodes = self.get_utilized()
        return (edge in edges)

    def modelspec(self):
        """Return the nasbench ``api.ModelSpec`` for this cell."""
        return api.ModelSpec(matrix=self.matrix, ops=self.ops)

    @classmethod
    def convert_to_cell(cls, arch):
        """Pad a spec with fewer than 7 vertices to the 7x7 form used by the API.

        A nasbench spec may have an n x n adjacency matrix for n < 7, but the
        nasbench API always uses 7x7 (possibly with blank rows).  The last
        column of the small matrix (edges into the output) is moved to
        column 6; the padding vertices get 'conv3x3-bn-relu' and no edges.
        Specs with 7 or more vertices are returned unchanged.
        """
        matrix, ops = arch['matrix'], arch['ops']
        n = len(matrix)
        if n >= 7:
            return {
                'matrix': matrix,
                'ops': ops
            }

        new_matrix = np.zeros((7, 7), dtype='int8')
        for i in range(n):
            for j in range(n - 1):
                new_matrix[i][j] = matrix[i][j]
            if n >= 1:
                # edges into the output vertex move to the last column
                new_matrix[i][-1] = matrix[i][n - 1]

        new_ops = [ops[i] if i < n - 1 else 'conv3x3-bn-relu' for i in range(6)]
        new_ops.append('output')
        return {
            'matrix': new_matrix,
            'ops': new_ops
        }

    @classmethod
    def random_cell_constrained(cls, nasbench, max_edges=10, max_nodes=8):
        """Sample a valid cell with <= max_edges utilized edges and <= max_nodes utilized vertices.

        Rejection sampling over random upper-triangular 0/1 matrices.
        """
        while True:
            matrix = np.random.choice(
                [0, 1], size=(NUM_VERTICES, NUM_VERTICES))
            matrix = np.triu(matrix, 1)
            edges, nodes = Cell(matrix=matrix, ops=[]).num_edges_and_vertices()
            if edges <= max_edges and nodes <= max_nodes:
                ops = np.random.choice(OPS, size=NUM_VERTICES).tolist()
                ops[0] = INPUT
                ops[-1] = OUTPUT
                spec = api.ModelSpec(matrix=matrix, ops=ops)
                if nasbench.is_valid(spec):
                    return {
                        'matrix': matrix,
                        'ops': ops
                    }

    @classmethod
    def random_cell_uniform(cls, nasbench):
        """Sample a cell uniformly at random from all hashes in the dataset."""
        hash_list = list(nasbench.hash_iterator())
        unique_hash = hash_list[np.random.randint(len(hash_list))]
        fix, _ = nasbench.get_metrics_from_hash(unique_hash)
        cell = {'matrix': fix['module_adjacency'], 'ops': fix['module_operations']}
        return cls.convert_to_cell(cell)

    @classmethod
    def random_cell(cls, nasbench):
        """Sample a valid cell as in the NASBench repository.

        Draws 0/1 for each upper-triangular slot of the adjacency matrix and
        a uniform op for each interior vertex, and retries until valid.
        """
        while True:
            matrix = np.random.choice(
                [0, 1], size=(NUM_VERTICES, NUM_VERTICES))
            matrix = np.triu(matrix, 1)
            ops = np.random.choice(OPS, size=NUM_VERTICES).tolist()
            ops[0] = INPUT
            ops[-1] = OUTPUT
            spec = api.ModelSpec(matrix=matrix, ops=ops)
            if nasbench.is_valid(spec):
                return {
                    'matrix': matrix,
                    'ops': ops
                }

    @classmethod
    def random_cell_continuous(cls, nasbench):
        """Sample a valid cell via a continuous adjacency matrix.

        Draws a number of edges n in [1, 8], draws a continuous [0, 1) value
        for every upper-triangular slot, and keeps the n largest as edges.
        """
        while True:
            values = np.random.random(size=(NUM_VERTICES, NUM_VERTICES))
            values = np.triu(values, 1)
            n = np.random.randint(8) + 1

            # value of the n-th largest entry
            flat = values.flatten()
            threshold = flat[np.argsort(flat)[-1 * n]]

            # convert to a model spec
            matrix = (values >= threshold).astype(int)

            ops = np.random.choice(OPS, size=NUM_VERTICES).tolist()
            ops[0] = INPUT
            ops[-1] = OUTPUT

            spec = api.ModelSpec(matrix=matrix, ops=ops)
            if nasbench.is_valid(spec):
                return {
                    'matrix': matrix,
                    'ops': ops
                }

    @classmethod
    def random_cell_path(cls, nasbench, index_hash, freq=-1):
        """Sample a valid cell via the one-hot path encoding.

        Each path index below the cutoff (all paths, or the first `freq`) is
        included independently with probability ``_PATH_PROBS[length]``.  The
        draw is retried until the resulting index tuple is a key of
        `index_hash` (sorted path-index tuple -> cell dict) and valid.
        """
        total_paths = cls._total_paths()
        cutoff = total_paths if freq < 0 else min(total_paths, freq)

        while True:
            path_indices = []
            n = 0
            # randomly sample paths; i is the number of ops on the path
            for i in range(OP_SPOTS + 1):
                for _ in range(len(OPS) ** i):
                    if np.random.rand() < cls._PATH_PROBS[i]:
                        path_indices.append(n)
                    n += 1
                    if n >= cutoff:
                        break
                if n >= cutoff:
                    break

            path_indices = tuple(sorted(path_indices))

            if path_indices in index_hash:
                spec = index_hash[path_indices]
                matrix = spec['matrix']
                ops = spec['ops']

                if nasbench.is_valid(api.ModelSpec(matrix, ops)):
                    return {
                        'matrix': matrix,
                        'ops': ops
                    }

    @classmethod
    def random_cell_path_cont(cls, nasbench, index_hash, freq=-1, weighted=False):
        """Sample a valid cell via a continuous path encoding.

        Draws a continuous [0, 1) score for every path index below the cutoff
        (multiplied by ``_PATH_PROBS[length]`` if `weighted`), draws a path
        count in [1, 6], and keeps that many top-scoring paths.  The draw is
        retried until the index tuple is in `index_hash` and valid.
        """
        total_paths = cls._total_paths()
        cutoff = total_paths if freq < 0 else min(total_paths, freq)

        while True:
            path_probs = []
            n = 0
            # give a continuous value to every path below the cutoff
            for i in range(OP_SPOTS + 1):
                for _ in range(len(OPS) ** i):
                    prob = np.random.rand()
                    if weighted:
                        prob *= cls._PATH_PROBS[i]
                    path_probs.append(prob)
                    n += 1
                    if n >= cutoff:
                        break
                if n >= cutoff:
                    break

            # threshold to get the num_paths best paths
            num_paths = np.random.choice([i for i in range(1, 7)])
            path_indices = tuple(sorted(np.argsort(path_probs)[-1 * num_paths:]))

            if path_indices in index_hash:
                spec = index_hash[path_indices]
                matrix = spec['matrix']
                ops = spec['ops']

                if nasbench.is_valid(api.ModelSpec(matrix, ops)):
                    return {
                        'matrix': matrix,
                        'ops': ops
                    }

    def _query(self, nasbench, epochs=None):
        """Run one ``nasbench.query`` for this cell (at `epochs` when truthy)."""
        spec = api.ModelSpec(matrix=self.matrix, ops=self.ops)
        if epochs:
            return nasbench.query(spec, epochs=epochs)
        return nasbench.query(spec)

    def _mean_error(self, nasbench, key, patience, epochs):
        """Average 100 * (1 - acc) over the distinct `key` values seen.

        Queries until three distinct accuracies are seen or `patience`
        queries are used up (a few architectures only have two).  The result
        is rounded to 4 decimals; if no query is made (patience <= 0), the
        mean of an empty list is NaN.
        """
        accs = []
        while len(accs) < 3 and patience > 0:
            patience -= 1
            acc = self._query(nasbench, epochs)[key]
            if acc not in accs:
                accs.append(acc)
        return round(100*(1-np.mean(accs)), 4)

    def get_val_loss(self, nasbench, deterministic=1, patience=50, epochs=None, dataset=None):
        """Return the validation error, 100 * (1 - validation_accuracy).

        If `deterministic` is falsy, the result of a single query is returned
        (one of the runs' accuracies at random).  Otherwise the distinct
        accuracies are averaged (see ``_mean_error``).  `dataset` is unused.
        """
        if not deterministic:
            return (100*(1 - self._query(nasbench, epochs)['validation_accuracy']))
        return self._mean_error(nasbench, 'validation_accuracy', patience, epochs)

    def get_test_loss(self, nasbench, patience=50, epochs=None, dataset=None):
        """Return the mean test error over the distinct per-run accuracies.

        Queries until three distinct accuracies are seen or `patience` runs
        out (a few architectures only have two).  `dataset` is unused.
        """
        return self._mean_error(nasbench, 'test_accuracy', patience, epochs)

    def get_num_params(self, nasbench):
        """Return the 'trainable_parameters' value reported by the API."""
        return self._query(nasbench)['trainable_parameters']

    def perturb(self, nasbench, edits=1):
        """Return a copy of this cell with `edits` random single edits applied.

        Each edit flips one random upper-triangular edge or replaces the op of
        one random interior vertex, each with probability 1/2.  The edit is
        redrawn from the last valid state until ``nasbench.is_valid`` accepts
        it.  Inspired by https://github.com/google-research/nasbench
        """
        new_matrix = copy.deepcopy(self.matrix)
        new_ops = copy.deepcopy(self.ops)
        edge_slots = [(src, dst)
                      for src in range(0, NUM_VERTICES - 1)
                      for dst in range(src + 1, NUM_VERTICES)]

        for _ in range(edits):
            while True:
                # Start every attempt from the last valid state so rejected
                # edits do not accumulate.
                candidate_matrix = copy.deepcopy(new_matrix)
                candidate_ops = copy.deepcopy(new_ops)
                if np.random.random() < 0.5:
                    src, dst = edge_slots[np.random.randint(len(edge_slots))]
                    candidate_matrix[src][dst] = 1 - candidate_matrix[src][dst]
                else:
                    ind = np.random.randint(1, NUM_VERTICES - 1)
                    available = [op for op in OPS if op != candidate_ops[ind]]
                    candidate_ops[ind] = np.random.choice(available)

                if nasbench.is_valid(api.ModelSpec(candidate_matrix, candidate_ops)):
                    new_matrix, new_ops = candidate_matrix, candidate_ops
                    break
        return {
            'matrix': new_matrix,
            'ops': new_ops
        }

    def mutate(self, nasbench,
               mutation_rate=1.0,
               encoding_type='adjacency',
               mutate_type='adjacency',
               cutoff=-1,
               comparisons=2500,
               patience=5000,
               prob_wt=False,
               index_hash=None):
        """Return a stochastically mutated copy of this cell as a cell dict.

        Similar to ``perturb``; inspired by https://github.com/google-research/nasbench

        mutate_type selects the scheme:
            'adj' / 'adjacency' / 'continuous': flip each upper-triangular
                edge w.p. mutation_rate / (number of slots) and change each
                interior op w.p. mutation_rate / OP_SPOTS.
            'adj_freq': the same, restricted by the truncation length `cutoff`.
            'adj_cat': remove and/or add one random edge, then mutate ops.
            'path' / 'path_freq': re-sample membership of each path index
                below the cutoff (40 for 'path_freq' unless `cutoff` > 0);
                the uniform draw is multiplied by mutation_rate.
            'path_cat' / 'path_cat_freq': remove and/or add one path index.
        Path-based schemes require `index_hash`, a mapping from sorted
        path-index tuples to cell dicts.  `prob_wt` makes 'path' use
        ``_PATH_PROBS`` instead of a uniform 1/end.

        `encoding_type` is only forwarded on the recursive retry, and
        `comparisons` is unused; both are kept for interface compatibility.

        Returns the first valid candidate found within `patience` attempts.
        When patience runs out, the 'adj'/'adjacency'/'continuous'/'adj_freq'
        schemes retry recursively with ``mutation_rate + 1`` (using the
        default mutate_type, cutoff and patience); every other scheme returns
        the unchanged cell.
        """
        p = 0
        if mutate_type in ['adj', 'adjacency', 'continuous']:
            # continuous isn't separated from adjacency for perturbations
            edge_mutation_prob = mutation_rate / (NUM_VERTICES * (NUM_VERTICES - 1) / 2)
            op_mutation_prob = mutation_rate / OP_SPOTS
            while p < patience:
                p += 1
                new_matrix = copy.deepcopy(self.matrix)
                new_ops = copy.deepcopy(self.ops)

                # flip each edge w.p. so expected flips is mutation_rate; same for ops
                for src in range(0, NUM_VERTICES - 1):
                    for dst in range(src + 1, NUM_VERTICES):
                        if random.random() < edge_mutation_prob:
                            new_matrix[src, dst] = 1 - new_matrix[src, dst]

                for ind in range(1, OP_SPOTS + 1):
                    if random.random() < op_mutation_prob:
                        available = [o for o in OPS if o != new_ops[ind]]
                        new_ops[ind] = random.choice(available)

                new_spec = api.ModelSpec(new_matrix, new_ops)
                if nasbench.is_valid(new_spec):
                    return {
                        'matrix': new_matrix,
                        'ops': new_ops
                    }
            return self.mutate(nasbench, mutation_rate+1, encoding_type=encoding_type)

        elif mutate_type == 'adj_freq':
            while p < patience:
                p += 1
                new_matrix = copy.deepcopy(self.matrix)
                new_ops = copy.deepcopy(self.ops)

                # (cutoff - 21) // 2 interior ops are eligible (0 when cutoff <= 21)
                trunc_op_spots = (max(cutoff, 21) - 21) // 2
                # An even cutoff >= 21 is bumped by one w.p. 1/2; the bump
                # persists across attempts.
                if cutoff >= 21 and ((cutoff % 2) == 0):
                    if np.random.rand() > .5:
                        cutoff += 1

                if trunc_op_spots > 0:
                    op_mutation_prob = mutation_rate / trunc_op_spots
                    for ind in range(1, trunc_op_spots + 1):
                        if random.random() < op_mutation_prob:
                            available = [o for o in OPS if o != new_ops[ind]]
                            new_ops[ind] = random.choice(available)

                # NOTE(review): max() keeps the divisor >= 21 even when fewer
                # edges are eligible below — confirm min() was not intended.
                trunc_edge_spots = max(cutoff, 21)
                if trunc_edge_spots > 0:
                    edge_mutation_prob = mutation_rate / trunc_edge_spots
                    # NOTE(review): n is decremented before the check, so only
                    # the first cutoff - 1 edges (row-major) can flip — confirm.
                    n = cutoff
                    for src in range(0, NUM_VERTICES - 1):
                        if n <= 0:
                            break
                        for dst in range(src + 1, NUM_VERTICES):
                            n -= 1
                            if n <= 0:
                                break
                            if random.random() < edge_mutation_prob:
                                new_matrix[src, dst] = 1 - new_matrix[src, dst]

                new_spec = api.ModelSpec(new_matrix, new_ops)
                if nasbench.is_valid(new_spec):
                    return {
                        'matrix': new_matrix,
                        'ops': new_ops
                    }
            return self.mutate(nasbench, mutation_rate+1, encoding_type=encoding_type)

        elif mutate_type == 'adj_cat':
            # remove and/or add one random edge, then change random ops
            triu_indices = [(i, j) for i in range(0, 6) for j in range(i + 1, 7)]
            op_mutation_prob = mutation_rate / OP_SPOTS
            while p < patience:
                p += 1  # without this the loop never ends if no valid candidate is drawn
                new_matrix = copy.deepcopy(self.matrix)
                new_ops = copy.deepcopy(self.ops)
                num_edges = np.array(self.matrix).sum()
                diff = np.random.choice([-1, 0, 1])

                if diff <= 0:
                    # remove the idx-th existing edge (row-major order);
                    # np.random.choice raises if the cell has no edges
                    idx = np.random.choice(range(num_edges))
                    counter = 0
                    for (i, j) in triu_indices:
                        if self.matrix[i][j] == 1:
                            if counter == idx:
                                new_matrix[i][j] = 0
                                break
                            counter += 1

                if diff >= 0:
                    # add the idx-th missing edge (row-major order)
                    idx = np.random.choice(range(len(triu_indices) - num_edges))
                    counter = 0
                    for (i, j) in triu_indices:
                        if self.matrix[i][j] == 0:
                            if counter == idx:
                                new_matrix[i][j] = 1
                                break
                            counter += 1

                for ind in range(1, OP_SPOTS + 1):
                    if random.random() < op_mutation_prob:
                        available = [o for o in OPS if o != new_ops[ind]]
                        new_ops[ind] = random.choice(available)

                new_spec = api.ModelSpec(new_matrix, new_ops)
                if nasbench.is_valid(new_spec):
                    return {
                        'matrix': new_matrix,
                        'ops': new_ops
                    }

        elif mutate_type in ['path', 'path_freq']:
            path_indices = self.get_path_indices()
            if mutate_type == 'path_freq' and cutoff > 0:
                end = cutoff
            elif mutate_type == 'path_freq':
                end = 40
            else:
                end = self._total_paths()

            if prob_wt:
                probs = self._PATH_PROBS
            else:
                probs = [1/end] * (OP_SPOTS + 1)

            while p < patience:
                p += 1
                new_path_indices = []
                n = 0
                # re-sample membership of every path index below `end`
                for i in range(OP_SPOTS + 1):
                    for _ in range(len(OPS)**i):
                        prob = np.random.rand() * mutation_rate
                        if prob < probs[i] and n not in path_indices:
                            # an absent path is switched on
                            new_path_indices.append(n)
                        elif prob > probs[i] and n in path_indices:
                            # a present path survives unless switched off
                            new_path_indices.append(n)
                        n += 1
                        if n >= end:
                            break
                    if n >= end:
                        break

                # Paths at or beyond the cutoff were not re-sampled; keep them.
                # (Using > here dropped the path with index exactly `end`.)
                new_path_indices.extend(path for path in path_indices if path >= end)
                new_path_indices = tuple(sorted(new_path_indices))

                if new_path_indices in index_hash:
                    spec = index_hash[new_path_indices]
                    matrix = spec['matrix']
                    ops = spec['ops']

                    if nasbench.is_valid(api.ModelSpec(matrix, ops)):
                        return {
                            'matrix': matrix,
                            'ops': ops
                        }

        elif mutate_type in ['path_cat', 'path_cat_freq']:
            path_indices = self.get_path_indices()
            if mutate_type == 'path_cat_freq' and cutoff > 0:
                end = cutoff
            elif mutate_type == 'path_cat_freq':
                end = 40
            else:
                end = self._total_paths()

            # path_indices is sorted, so its first num_paths entries are the
            # ones below the cutoff
            num_paths = len([i for i in path_indices if i < end])
            addable = [i for i in range(end) if i not in path_indices]

            while p < patience:
                p += 1
                new_path_indices = list(path_indices)
                diff = np.random.choice([-1, 0, 1])

                # choose a random path to remove
                if diff <= 0:
                    if num_paths > 0:
                        idx = np.random.choice(list(range(num_paths)))
                        new_path_indices.remove(path_indices[idx])
                    else:
                        diff = 1

                # choose a random path to add
                if diff >= 0:
                    if addable:
                        new_path_indices.append(np.random.choice(addable))
                    elif new_path_indices:
                        new_path_indices.pop()

                new_path_indices = tuple(sorted(new_path_indices))

                if new_path_indices in index_hash:
                    spec = index_hash[new_path_indices]
                    matrix = spec['matrix']
                    ops = spec['ops']

                    if nasbench.is_valid(api.ModelSpec(matrix, ops)):
                        return {
                            'matrix': matrix,
                            'ops': ops
                        }

        return {'matrix': self.matrix, 'ops': self.ops}

    def encode_standard(self):
        """Return the "standard" encoding (adjacency + op list) as a tuple.

        Layout: the upper-triangle adjacency entries in row-major order,
        followed by the interior ops mapped to {conv1x1: 0, conv3x3: 0.5,
        maxpool: 1} in reverse vertex order (op i is stored at index -i).
        """
        encoding_length = (NUM_VERTICES ** 2 - NUM_VERTICES) // 2 + OP_SPOTS
        encoding = np.zeros((encoding_length))
        dic = {CONV1X1: 0., CONV3X3: 0.5, MAXPOOL3X3: 1.0}
        n = 0
        for i in range(NUM_VERTICES - 1):
            for j in range(i+1, NUM_VERTICES):
                encoding[n] = self.matrix[i][j]
                n += 1
        for i in range(1, NUM_VERTICES - 1):
            encoding[-i] = dic[self.ops[i]]
        return tuple(encoding)

    def encode_adj_cat(self):
        """Return the categorical adjacency encoding as a tuple.

        The leading slots hold the row-major upper-triangle indices of the
        present edges (remaining edge slots stay 0), followed by the interior
        ops in the same reverse-order layout as ``encode_standard``.
        """
        encoding_length = (NUM_VERTICES ** 2 - NUM_VERTICES) // 2 + OP_SPOTS
        encoding = np.zeros((encoding_length))
        dic = {CONV1X1: 0., CONV3X3: 0.5, MAXPOOL3X3: 1.0}
        n = 0
        m = 0
        for i in range(NUM_VERTICES - 1):
            for j in range(i+1, NUM_VERTICES):
                if self.matrix[i][j]:
                    encoding[m] = n
                    m += 1
                n += 1

        for i in range(1, NUM_VERTICES - 1):
            encoding[-i] = dic[self.ops[i]]
        return tuple(encoding)

    def encode_continuous(self):
        """Return the continuous encoding as a tuple.

        Layout: upper-triangle adjacency entries (row-major), then the
        interior ops in reverse vertex order (as in ``encode_standard``),
        then the total number of edges.
        """
        encoding_length = (NUM_VERTICES ** 2 - NUM_VERTICES) // 2 + OP_SPOTS + 1
        encoding = np.zeros((encoding_length))
        dic = {CONV1X1: 0., CONV3X3: 0.5, MAXPOOL3X3: 1.0}
        n = 0
        for i in range(NUM_VERTICES - 1):
            for j in range(i+1, NUM_VERTICES):
                encoding[n] = self.matrix[i][j]
                n += 1
        # Op i goes to index -(i + 1) so the final slot is free for the edge
        # count and no op value gets overwritten.
        for i in range(1, NUM_VERTICES - 1):
            encoding[-i - 1] = dic[self.ops[i]]
        encoding[-1] = np.sum(self.matrix)
        return tuple(encoding)

    def get_paths(self):
        """Return every input->output path as the list of ops on its interior vertices."""
        paths = [[[]] if self.matrix[0][j] else [] for j in range(NUM_VERTICES)]

        # create paths sequentially (relies on an upper-triangular matrix)
        for i in range(1, NUM_VERTICES - 1):
            for j in range(1, NUM_VERTICES):
                if self.matrix[i][j]:
                    for path in paths[i]:
                        paths[j].append([*path, self.ops[i]])
        return paths[-1]

    def get_path_indices(self):
        """Return the sorted tuple of path indices (one per input->output path).

        There are 3^0 + ... + 3^5 possible paths: a path has 0 to 5 interior
        ops and each op has three choices.  The index is a bijective base-3
        number, so the empty path is 0 and paths with k ops occupy one
        contiguous range after all shorter paths.
        """
        mapping = {CONV3X3: 0, CONV1X1: 1, MAXPOOL3X3: 2}
        path_indices = [
            sum(len(OPS) ** i * (mapping[op] + 1) for i, op in enumerate(path))
            for path in self.get_paths()
        ]
        path_indices.sort()
        return tuple(path_indices)

    def encode_paths(self):
        """Return the one-hot path encoding as a float ndarray."""
        path_indices = self.get_path_indices()
        encoding = np.zeros(self._total_paths())
        for index in path_indices:
            encoding[index] = 1
        return encoding

    def encode_freq_paths(self, cutoff=40):
        """Return the first `cutoff` entries of the one-hot path encoding."""
        return self.encode_paths()[:cutoff]

    def path_distance(self, other):
        """Hamming distance between the two cells' path encodings."""
        return np.sum(self.encode_paths() != other.encode_paths())

    def path_cont_distance(self, other):
        """Greedy matching distance between path index sets (work in progress).

        Repeatedly pairs the closest remaining (self, other) path indices and
        sums their absolute index differences.  It then adds 400 per
        unmatched path and divides the total by 100.
        """
        path_indices = list(self.get_path_indices())
        other_indices = list(other.get_path_indices())
        path_diffs = abs(len(path_indices) - len(other_indices))
        total_dist = 0
        for _ in range(min(len(path_indices), len(other_indices))):
            # ties go to the smallest own index, then the smallest other index
            min_dist, min_path, min_other = min(
                (abs(mine - theirs), mine, theirs)
                for mine in path_indices
                for theirs in other_indices)
            total_dist += min_dist
            path_indices.remove(min_path)
            other_indices.remove(min_other)
        return (total_dist + 400 * path_diffs) / 100

    def freq_distance(self, other, cutoff=40):
        """Hamming distance between the first `cutoff` path-encoding entries."""
        encoding = self.encode_paths()[:cutoff]
        other_encoding = other.encode_paths()[:cutoff]
        return np.sum(encoding != other_encoding)

    def edit_distance(self, other):
        """Number of differing adjacency entries plus number of differing ops."""
        graph_dist = np.sum(np.array(self.matrix) != np.array(other.matrix))
        ops_dist = np.sum(np.array(self.ops) != np.array(other.ops))
        return (graph_dist + ops_dist)

    def cont_path_distance(self, other):
        """Path-encoding Hamming distance plus the difference in path counts."""
        encoding = self.encode_paths()
        other_encoding = other.encode_paths()
        path_dist = np.sum(np.array(encoding) != np.array(other_encoding))
        return path_dist + np.abs(np.sum(encoding) - np.sum(other_encoding))

    def cont_adj_distance(self, other):
        """Edit distance plus the difference in total edge counts."""
        num_edges = np.array(self.matrix).sum()
        other_edges = np.array(other.matrix).sum()
        graph_dist = np.sum(np.array(self.matrix) != np.array(other.matrix))
        ops_dist = np.sum(np.array(self.ops) != np.array(other.ops))
        return (graph_dist + ops_dist + np.abs(num_edges - other_edges))

    def nasbot_distance(self, other):
        """Optimal-transport style distance over sorted degrees and op counts."""
        # axis=0 sums each column (in-degrees); axis=1 sums each row (out-degrees)
        in_degrees = sorted(np.array(self.matrix).sum(axis=0))
        out_degrees = sorted(np.array(self.matrix).sum(axis=1))
        other_in_degrees = sorted(np.array(other.matrix).sum(axis=0))
        other_out_degrees = sorted(np.array(other.matrix).sum(axis=1))

        in_dist = np.sum(np.abs(np.subtract(in_degrees, other_in_degrees)))
        out_dist = np.sum(np.abs(np.subtract(out_degrees, other_out_degrees)))

        counts = [self.ops.count(op) for op in OPS]
        other_counts = [other.ops.count(op) for op in OPS]
        ops_dist = np.sum(np.abs(np.subtract(counts, other_counts)))

        return (in_dist + out_dist + ops_dist)

    def get_neighborhood(self, nasbench,
                         nbhd_type='op',
                         index_hash=None,
                         shuffle=True):
        """Return valid cells one edit away from this one, as cell dicts.

        nbhd_type:
            'path' / 'path_freq': remove or add one path index (adds limited
                to the first 40 indices for 'path_freq') and look the sorted
                result up in `index_hash`, which is required for these types.
            anything else: change the op of one utilized interior vertex, or
                flip one edge that ends up on an input->output path;
                'add_edge' skips edge removals and 'sub_edge' skips edge
                additions (op changes are always included).
        The list is shuffled in place with ``random.shuffle`` if `shuffle`.
        """
        nbhd = []

        if nbhd_type not in ['path', 'path_freq']:
            # compute the utilized sub-graph once rather than per candidate
            utilized_edges, utilized_nodes = self.get_utilized()

            # op neighbors
            for vertex in range(1, OP_SPOTS + 1):
                if vertex in utilized_nodes:
                    for op in OPS:
                        if op == self.ops[vertex]:
                            continue
                        new_ops = copy.deepcopy(self.ops)
                        new_ops[vertex] = op
                        nbhd.append({'matrix': copy.deepcopy(self.matrix), 'ops': new_ops})

            # edge neighbors
            for src in range(0, NUM_VERTICES - 1):
                for dst in range(src+1, NUM_VERTICES):
                    new_matrix = copy.deepcopy(self.matrix)
                    new_ops = copy.deepcopy(self.ops)
                    new_matrix[src][dst] = 1 - new_matrix[src][dst]
                    new_arch = {'matrix': new_matrix, 'ops': new_ops}

                    if self.matrix[src][dst]:
                        # removal: only edges currently on an input->output path
                        keep = nbhd_type != 'add_edge' and (src, dst) in utilized_edges
                    else:
                        # addition: the new edge must land on an input->output path
                        keep = (nbhd_type != 'sub_edge'
                                and Cell(**new_arch).is_valid_edge((src, dst)))
                    if keep and nasbench.is_valid(api.ModelSpec(matrix=new_matrix, ops=new_ops)):
                        nbhd.append(new_arch)

        else:
            path_indices = self.get_path_indices()
            end = 40 if nbhd_type == 'path_freq' else self._total_paths()

            new_sets = []
            # remove one path (only indices below the cutoff)
            for path in path_indices:
                if path < end:
                    new_sets.append(tuple(p for p in path_indices if p != path))

            # add one path; keep tuples sorted to match the index_hash keys
            for path in range(end):
                if path not in path_indices:
                    new_sets.append(tuple(sorted([*path_indices, path])))

            for new_tuple in new_sets:
                if new_tuple in index_hash:
                    spec = index_hash[new_tuple]
                    model_spec = api.ModelSpec(matrix=spec['matrix'], ops=spec['ops'])
                    if nasbench.is_valid(model_spec):
                        nbhd.append(spec)

        if shuffle:
            random.shuffle(nbhd)
        return nbhd
