import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from networkx.generators.nonisomorphic_trees import nonisomorphic_trees


# =========================================
# Dataset creation functions
# =========================================

def glue_trees(tree1, tree2):
    """
    Glues two trees together to construct a graph as in the lower bound.

    Both trees are laid out hierarchically from node 0, which is treated as the
    root. Every node gets these attributes:
      - 'owner': 'alice' for tree1, 'bob' for tree2
      - 'label': the 1-based node id as a string, so the root is '1'
      - 'color': 'b' for tree1, 'g' for tree2
      - 'pos': a drawing position; tree1 is on the left, tree2 on the right,
        and the two roots face each other.

    The two roots are then joined by a single edge.

    Parameters
    ----------
    tree1, tree2 : nx.Graph
        Trees with integer node ids that include node 0. They are copied, not
        modified.

    Returns
    -------
    nx.Graph
        The disjoint union, relabelled by ``disjoint_union``: tree1's nodes come
        first (0..n1-1) and tree2's follow (n1..n1+n2-1), plus the root-to-root
        edge.
    """
    tree1 = tree1.copy()
    tree2 = tree2.copy()

    pos1 = normalize_pos(hierarchy_pos(tree1, 0, vert_loc=0.5))
    for node in tree1.nodes:
        tree1.nodes[node]['owner'] = 'alice'
        tree1.nodes[node]['label'] = f'{node+1}'
        tree1.nodes[node]['color'] = 'b'
        # Rotate the layout so tree1 grows to the left, with its root near x=0.4.
        tree1.nodes[node]['pos'] = [pos1[node][1]-0.6, pos1[node][0]]

    pos2 = normalize_pos(hierarchy_pos(tree2, 0, vert_loc=0.5))
    for node in tree2.nodes:
        tree2.nodes[node]['owner'] = 'bob'
        tree2.nodes[node]['label'] = f'{node+1}'
        tree2.nodes[node]['color'] = 'g'
        # Mirror image: tree2 grows to the right, with its root near x=0.6.
        tree2.nodes[node]['pos'] = [1-pos2[node][1]+0.6, pos2[node][0]]

    graph = nx.algorithms.disjoint_union(tree1, tree2)

    # disjoint_union renumbers each tree's nodes in its own iteration order,
    # shifting tree2's ids by n1. Locate where each root (node 0) ended up.
    # Previously the edge went to n1 + 1 (tree2's node 1, not its root); for a
    # single-node tree2 that silently created an extra, attribute-less node.
    root1 = list(tree1.nodes).index(0)
    root2 = tree1.number_of_nodes() + list(tree2.nodes).index(0)
    graph.add_edge(root1, root2)

    return graph

def sample_glued_tree(tree_universe):
    """
    Samples a random glued graph from a universe of trees.

    Two indices into ``tree_universe`` are drawn uniformly with replacement with
    ``np.random.choice``. The two trees are glued with ``glue_trees``, and the
    node ids are then shuffled independently within each side.

    Parameters
    ----------
    tree_universe : sequence of nx.Graph
        The candidate trees.

    Returns
    -------
    graph_relabeled : nx.Graph
        Nodes are 0..n1+n2-1; the first n1 nodes come from tree1. Each node
        carries:
          - 'owner', 'color', 'pos': copied from the glued graph
          - 'label': 1-based position within its own side, as a string
          - 'is_root': True if the node was labelled '1' in the glued graph
    isomorphism_class : str
        ``str(i)`` when both indices are equal, otherwise ``f'{lo}_{hi}'`` with
        ``lo < hi``. The separator keeps classes unique; plain concatenation
        made e.g. (1, 213) and (12, 13) collide.
    tree1_idx, tree2_idx
        The sampled indices.
    """
    n_universe = len(tree_universe)

    # Sample two trees (with replacement).
    tree1_idx, tree2_idx = np.random.choice(n_universe, 2)
    tree1 = tree_universe[tree1_idx]
    tree2 = tree_universe[tree2_idx]

    graph = glue_trees(tree1, tree2)

    # Randomly permute nodes within each side; the sides stay contiguous blocks.
    n1 = tree1.number_of_nodes()
    n2 = tree2.number_of_nodes()
    p = np.append(np.random.permutation(n1), n1 + np.random.permutation(n2))
    # Assumes glued-graph nodes are 0..n1+n2-1 in iteration order, as
    # disjoint_union produces. from_numpy_matrix was removed in networkx 3.0.
    A = nx.to_numpy_array(graph, dtype=int)[np.ix_(p, p)]
    graph_relabeled = nx.from_numpy_array(A)
    for node in graph_relabeled.nodes:
        source = graph.nodes[p[node]]
        side_position = node if node < n1 else node - n1
        graph_relabeled.nodes[node]['owner'] = source['owner']
        graph_relabeled.nodes[node]['label'] = str(side_position + 1)
        graph_relabeled.nodes[node]['color'] = source['color']
        graph_relabeled.nodes[node]['pos'] = source['pos']
        graph_relabeled.nodes[node]['is_root'] = source['label'] == '1'

    # The class of the unordered pair of tree indices.
    lo, hi = sorted((int(tree1_idx), int(tree2_idx)))
    if lo == hi:
        isomorphism_class = str(lo)
    else:
        isomorphism_class = f'{lo}_{hi}'

    return graph_relabeled, isomorphism_class, tree1_idx, tree2_idx
            
    
# =========================================
# Pytorch functions
# =========================================

def glued_dataset_to_torch(dataset, max_degree=None, unique_ids=False):
    """
    Converts a glued-graph dataset into a list of torch_geometric ``Data`` objects.

    Parameters
    ----------
    dataset : sequence of dict
        Each item has:
          - 'graph': an nx graph whose nodes are 0..N-1 and carry 'owner' and
            'label' attributes
          - 'label': an indicator vector; the index of its first nonzero entry
            becomes the target.
        The sequence is iterated twice when ``max_degree`` is None.
    max_degree : int, optional
        Size of the one-hot degree encoding. Only degrees below it can be
        encoded. When None, it is set to (largest degree in the dataset) + 1.
    unique_ids : bool
        If True, a one-hot node-id block of width N is appended to the
        features (used for the random-label experiment).

    Returns
    -------
    list of torch_geometric.data.Data
        Each item has:
          - x: node features with these columns:
              - [0]: owner; 0 for 'alice', 1 otherwise
              - [1]: 1 only on alice's root (label '1')
              - [2 : max_degree+2]: one-hot degree
              - then the optional one-hot node id
          - edge_index: both directions of every edge
          - y: the class index

    Raises
    ------
    ValueError
        If a node's degree does not fit in the ``max_degree`` one-hot encoding.
    """
    import torch
    from torch_geometric.data import Data

    if max_degree is None:
        degrees = [deg for datum in dataset for _, deg in nx.degree(datum['graph'])]
        # np.int was removed in NumPy 1.24; the builtin int is equivalent.
        max_degree = int(max(degrees)) + 1
        print(f'setting max degree to: {max_degree}')

    dataset_torch = []
    for datum in dataset:

        graph, label = datum['graph'], datum['label']
        n_nodes = graph.number_of_nodes()

        # Every undirected edge is listed in both directions.
        pairs = [pair for u, v in graph.edges() for pair in ((u, v), (v, u))]
        edge_index = np.array(pairs, dtype=np.int64).reshape(-1, 2)
        edge_index = torch.tensor(edge_index.T, dtype=torch.long)

        n_id_columns = n_nodes if unique_ids else 0
        x = np.zeros((n_nodes, max_degree + 2 + n_id_columns), dtype=float)

        for node in graph.nodes:
            attrs = graph.nodes[node]
            is_alice = attrs['owner'] == 'alice'

            # Reveal the owner of each node.
            x[node, 0] = 0 if is_alice else 1

            # Reveal the root, on alice's side only.
            if is_alice and attrs['label'] == '1':
                x[node, 1] = 1

            # Reveal the degree (one-hot). Setting one entry avoids building an
            # identity matrix per node.
            degree = graph.degree(node)
            if degree >= max_degree:
                raise ValueError(
                    f'node {node} has degree {degree}, but max_degree={max_degree} '
                    f'only encodes degrees below {max_degree}'
                )
            x[node, 2 + degree] = 1

            # Node ids (for the random label experiment). Setting one entry
            # replaces the O(N^2) np.eye(N)[node] row per node.
            if unique_ids:
                x[node, max_degree + 2 + node] = 1

        x = torch.tensor(x, dtype=torch.float)

        # Graph label: index of the first nonzero entry of the indicator.
        y = torch.tensor([np.where(label)[0][0]], dtype=torch.long)

        dataset_torch.append(Data(x=x, edge_index=edge_index, edge_attr=None, y=y))

    return dataset_torch

# =========================================
# Counting functions 
#  The sequence of unlabeled-tree counts and Otter's asymptotic formula are taken from https://oeis.org/A000055
# =========================================

# Number of unlabeled (free) trees, indexed so that A000055[k] counts the trees
# on k + 1 nodes. The counting functions read the trees on m nodes as
# A000055[m - 1]. This is OEIS A000055 starting from the 1-node term.
A000055 = np.array([
    1, 1, 1, 2, 3, 6, 11, 23, 47, 106,
    235, 551, 1301, 3159, 7741, 19320, 48629, 123867, 317955, 823065,
    2144505, 5623756, 14828074, 39299897, 104636890, 279793450, 751065460,
    2023443032, 5469566585, 14830871802, 40330829030, 109972410221,
    300628862480, 823779631721, 2262366343746, 6226306037178,
])

def n_glued_trees(n_all, asymptotic=False):
    """
    Counts glued graphs on n nodes, for each n in ``n_all``.

    A glued graph consists of two trees on n/2 nodes each. The two sides are
    interchangeable, so with t trees available the count is the number of
    unordered pairs with repetition, t * (t + 1) / 2.

    The tree count t is:
      - exact, from the A000055 table, when n is even, n/2 is covered by the
        table, and ``asymptotic`` is False;
      - otherwise Otter's asymptotic estimate, floored and clipped to >= 1.

    Parameters
    ----------
    n_all : iterable of int
        Total numbers of nodes.
    asymptotic : bool
        Always use Otter's estimate, even where the table is exact.

    Returns
    -------
    np.ndarray
        An integer array. If some count exceeds the platform integer range, an
        object array of exact Python ints is returned instead of wrapped-around
        values.
    """
    result = []
    for n in n_all:
        half = n / 2
        # A000055[k] counts the trees on k + 1 nodes, so valid sizes are
        # 1..len(A000055). The old bound `half < len` skipped the last entry,
        # and n = 0 read A000055[-1].
        if (not asymptotic) and (n % 2 == 0) and (1 <= half <= len(A000055)):
            # Python int, so the product below cannot overflow int64.
            t = int(A000055[int(half) - 1])
        else:
            # Otter's formula for the number of unlabeled trees on n/2 nodes.
            t = int(max(1.0, np.floor(0.534949606 * (2.95576528565)**half * half**(-5/2))))

        # Account for the fact that there are two (interchangeable) sides.
        result.append(t * (t + 1) // 2)

    try:
        # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
        return np.array(result, dtype=int)
    except OverflowError:
        return np.array(result, dtype=object)

def n_labeled_glued_trees(n_all, asymptotic=False):
    """
    Counts node-labelled glued graphs on n nodes, for each n in ``n_all``.

    Each side is one of t unlabeled trees on n/2 nodes, multiplied by
    (n/2)! node labellings. The two sides are counted independently,
    giving (t * (n/2)!)^2.

    The tree count t is:
      - exact, from the A000055 table, when n is even, n/2 is covered by the
        table, and ``asymptotic`` is False;
      - otherwise Otter's asymptotic estimate, floored and clipped to >= 1.

    Parameters
    ----------
    n_all : iterable of int
        Total numbers of nodes.
    asymptotic : bool
        Always use Otter's estimate, even where the table is exact.

    Returns
    -------
    np.ndarray
        Floating-point counts.
    """
    # Import the submodule explicitly; a bare `import scipy` does not
    # guarantee that scipy.special is loaded.
    from scipy.special import factorial

    result = []
    for n in n_all:
        half = n / 2
        # A000055[k] counts the trees on k + 1 nodes, so valid sizes are
        # 1..len(A000055). The old bound `half < len` skipped the last entry,
        # and n = 0 read A000055[-1]. np.int was removed in NumPy 1.24.
        if (not asymptotic) and (n % 2 == 0) and (1 <= half <= len(A000055)):
            t = A000055[int(half) - 1]
        else:
            # Otter's formula for the number of unlabeled trees on n/2 nodes.
            t = np.clip(np.floor(0.534949606 * (2.95576528565)**half * half**(-5/2)), 1, np.inf)

        # Count all node labellings of one side.
        t = t * factorial(half)
        # Account for the fact that there are two sides.
        result.append(t * t)
    return np.array(result)

# =========================================
# Plotting functions 
# =========================================

def draw_glued_graph_paper(graph, figsize=(6, 5), show_labels=False):
    """
    Draws a glued graph nicely (for paper illustration).

    Node colours:
      - nodes with owner 'alice': plt.cm.RdYlGn(0.35)
      - all other nodes: plt.cm.RdYlGn(0.9)
    Each node is placed at its 'pos' attribute.

    Parameters
    ----------
    graph : nx.Graph
        Nodes carry 'owner' and 'pos' attributes, plus 'label' when
        ``show_labels`` is True.
    figsize : tuple
        Figure size passed to ``plt.figure``.
    show_labels : bool
        If True, also draw each node's 'label' attribute. Off by default, which
        matches the previous behaviour.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    alice_color = plt.cm.RdYlGn(0.35)
    other_color = plt.cm.RdYlGn(0.9)
    # Colour list aligned with graph.nodes order (no assumption that ids are 0..N-1).
    colors = [alice_color if graph.nodes[i]['owner'] == 'alice' else other_color
              for i in graph.nodes]
    # networkx drawing functions look positions up by node, so use a dict.
    pos = {i: graph.nodes[i]['pos'] for i in graph.nodes}

    fig = plt.figure(figsize=figsize, facecolor=[1, 1, 1])
    ax = fig.add_subplot(1, 1, 1)
    nx.draw_networkx_edges(graph, pos, alpha=0.85, width=1.5, ax=ax)
    nx.draw_networkx_nodes(graph, pos, node_size=200, alpha=1, node_color=colors,
                           linewidths=1.25, edgecolors=[0.12, 0.12, 0.12], ax=ax)
    if show_labels:
        labels = {i: graph.nodes[i]['label'] for i in graph.nodes}
        nx.draw_networkx_labels(graph, pos, alpha=1, labels=labels,
                                font_color=[0, 0, 0], font_size=8, ax=ax)
    ax.axis('equal')
    ax.axis('off')
    return fig


def draw_glued_tree(graph, figsize=(16, 10)):
    """
    Draws a glued tree and shows it with ``plt.show()``.

    Each node is drawn at its 'pos' attribute, filled with its 'color'
    attribute and annotated with its 'label' attribute.

    Parameters
    ----------
    graph : nx.Graph
        Nodes carry 'color', 'label' and 'pos' attributes.
    figsize : tuple
        Figure size passed to ``plt.figure``.
    """
    colors = [graph.nodes[i]['color'] for i in graph.nodes]
    labels = {i: graph.nodes[i]['label'] for i in graph.nodes}
    # networkx looks positions up by node. A positional list only happened to
    # work when the node ids were exactly 0..N-1.
    pos = {i: graph.nodes[i]['pos'] for i in graph.nodes}

    plt.figure(figsize=figsize)
    nx.draw_networkx_edges(graph, pos, alpha=0.8, width=1.5)
    nx.draw_networkx_nodes(graph, pos, node_size=700, alpha=0.95, node_color=colors, edgecolors=[0, 0, 0])
    nx.draw_networkx_labels(graph, pos, alpha=1, labels=labels, font_color=[1, 1, 1], font_size=12)
    plt.axis('equal')
    plt.axis('off')
    plt.show()
    
    
def normalize_pos(pos):
    """
    Normalizes a nx pos dictionary to be within a [0,1]^2 box.

    Each coordinate axis is min-max scaled independently. The 1e-10 added to
    the denominator maps a degenerate axis (every node shares one coordinate)
    to 0 instead of dividing by zero.

    Parameters
    ----------
    pos : dict
        Maps node -> (x, y).

    Returns
    -------
    dict
        Same keys in the same order, mapping to rescaled (x, y) tuples. An empty
        input returns an empty dict.
    """
    if not pos:
        return {}

    def _rescale(values):
        # Min-max scale one axis into [0, 1].
        values = np.asarray(values)
        lo = values.min()
        return (values - lo) / (values.max() - lo + 1e-10)

    nodes = list(pos.keys())
    xs = _rescale([pos[node][0] for node in nodes])
    ys = _rescale([pos[node][1] for node in nodes])
    return dict(zip(nodes, zip(xs, ys)))


def hierarchy_pos(G, root=None, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5):

    '''
    From Joel's answer at https://stackoverflow.com/a/29597209/2966723.  
    Licensed under Creative Commons Attribution-Share Alike 

    If the graph is a tree this will return the positions to plot this in a 
    hierarchical layout.

    G: the graph (must be a tree)

    root: the root node of current branch 
    - if the tree is directed and this is not given, 
      the root will be found and used
    - if the tree is directed and this is given, then 
      the positions will be just for the descendants of this node.
    - if the tree is undirected and not given, 
      then a random choice will be used.

    width: horizontal space allocated for this branch - avoids overlap with other branches
    vert_gap: gap between levels of hierarchy
    vert_loc: vertical location of root
    xcenter: horizontal location of root

    Returns a dict mapping node -> (x, y). Raises TypeError if G is not a tree.
    '''
    # `random` is used below but was never imported, so calling this on an
    # undirected tree without a root raised NameError.
    import random

    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')

    if root is None:
        if isinstance(G, nx.DiGraph):
            root = next(iter(nx.topological_sort(G)))  #allows back compatibility with nx version 1.11
        else:
            root = random.choice(list(G.nodes))

    def _hierarchy_pos(G, root, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5, pos = None, parent = None):
        '''
        see hierarchy_pos docstring for most arguments
        pos: a dict saying where all nodes go if they have been assigned
        parent: parent of this branch. - only affects it if non-directed

        '''

        if pos is None:
            pos = {root:(xcenter,vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)
        children = list(G.neighbors(root))
        if not isinstance(G, nx.DiGraph) and parent is not None:
            # In an undirected tree the parent shows up as a neighbour; skip it.
            children.remove(parent)  
        if len(children)!=0:
            # Split this branch's width evenly among the children, centred on xcenter.
            dx = width/len(children) 
            nextx = xcenter - width/2 - dx/2
            for child in children:
                nextx += dx
                pos = _hierarchy_pos(G,child, width = dx, vert_gap = vert_gap, 
                                    vert_loc = vert_loc-vert_gap, xcenter=nextx,
                                    pos=pos, parent = root)
        return pos


    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)