import argparse, time
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from models.dgi_cora import DGI, MultiClassifier
from models.subgi_models_v2 import SubGI
from models import VGAE #, SUBVGAE
from IPython import embed
import scipy.sparse as sp
from collections import defaultdict
from torch.autograd import Variable
from src.misc.graph_generation.create_graphs import create, generate_label
from tqdm import tqdm
import pickle
from collections import defaultdict


def evaluate(model, features, labels, mask):
    """Return the accuracy of ``model`` on the entries selected by ``mask``.

    The model is switched to eval mode and run once on ``features`` without
    gradient tracking. Predictions are the arg-max over dimension 1 of the
    masked logits, and the result is the fraction that equal the masked
    ``labels``.

    Args:
        model: callable ``torch.nn.Module`` mapping ``features`` to logits.
        features: model input.
        labels: ground-truth class indices, indexable by ``mask``.
        mask: index or boolean mask selecting the entries to score.

    Returns:
        float accuracy in [0, 1].
    """
    model.eval()
    with torch.no_grad():
        selected_logits = model(features)[mask]
        selected_labels = labels[mask]
        _, predicted = selected_logits.max(dim=1)
        n_correct = (predicted == selected_labels).sum().item()
    return n_correct * 1.0 / len(selected_labels)

def degree_bucketing(graph, args, degree_emb=None, max_degree = 10):
    """Return a constant all-ones node-feature matrix.

    NOTE(review): despite its name, this function does NOT bucket nodes by
    degree. The original one-hot degree encoding sat after an early
    ``return`` and never ran, so it has been removed; the observable
    behaviour is unchanged.

    Args:
        graph: object exposing ``number_of_nodes()``.
        args: namespace whose ``n_hidden`` sets the feature width.
        degree_emb: unused; kept for call compatibility.
        max_degree: unused; the width always comes from ``args.n_hidden``
            (the original overwrote this parameter).

    Returns:
        float tensor of ones with shape ``(graph.number_of_nodes(), args.n_hidden)``.
    """
    feature_dim = args.n_hidden
    return torch.ones([graph.number_of_nodes(), feature_dim])

def createTraining(labels, valid_mask=None, train_ratio=0.8):
    """Randomly split entries into complementary train/test boolean masks.

    A shuffled ordering of ``range(labels.shape[0])`` (drawn from NumPy's
    global RNG) is generated. The first ``int(n * train_ratio)`` indices
    become training entries and the rest become test entries. When
    ``valid_mask`` is given, both masks are AND-ed with it in place.

    Args:
        labels: tensor whose shape determines the mask shape.
        valid_mask: optional mask restricting both splits.
        train_ratio: fraction of entries assigned to training.

    Returns:
        ``(train_mask, test_mask)`` as ``torch.bool`` tensors.
    """
    n_items = labels.shape[0]
    order = list(range(n_items))
    np.random.shuffle(order)
    picked = order[:int(n_items * train_ratio)]

    train_mask = torch.zeros(labels.shape, dtype=torch.bool)
    train_mask[picked] = 1
    test_mask = torch.ones(labels.shape, dtype=torch.bool)
    test_mask[picked] = 0

    if valid_mask is None:
        return train_mask, test_mask
    train_mask *= valid_mask
    test_mask *= valid_mask
    return train_mask, test_mask

def read_struct_net(file_path):
    """Read a whitespace-separated edge list into an undirected ``nx.Graph``.

    Each non-blank line must begin with two integer node ids; any further
    columns are ignored. Blank lines (for example a trailing newline at the
    end of the file) are skipped instead of raising ``IndexError``.

    Args:
        file_path: path to the edge-list text file.

    Returns:
        ``networkx.Graph`` containing every listed edge.
    """
    g = nx.Graph()
    with open(file_path) as edge_file:
        for line in edge_file:
            fields = line.split()
            if not fields:
                continue
            g.add_edge(int(fields[0]), int(fields[1]))
    return g
    
def constructDGL(graph):
    """Convert a networkx graph into a bidirected ``DGLGraph``.

    Nodes are relabelled ``0..n-1`` in sorted order of their original ids.
    Each edge ``(u, v)`` of ``graph`` is inserted as ``u->v`` and then
    ``v->u``, skipping any directed pair already inserted. A self-loop is
    therefore added exactly once. Edge ids follow the same order as the
    original one-edge-at-a-time insertion.

    Args:
        graph: networkx graph with sortable node ids.

    Returns:
        a new mutable ``DGLGraph``.
    """
    # Plain dict: an unknown endpoint should fail loudly rather than map to 0.
    node_mapping = {node: idx for idx, node in enumerate(sorted(graph.nodes()))}

    seen = set()
    src, dst = [], []
    for u, v in graph.edges():
        a, b = node_mapping[u], node_mapping[v]
        for pair in ((a, b), (b, a)):
            if pair not in seen:
                seen.add(pair)
                src.append(pair[0])
                dst.append(pair[1])

    new_g = DGLGraph()
    new_g.add_nodes(len(node_mapping))
    # One bulk insertion instead of per-edge has_edge_between/add_edge calls.
    if src:
        new_g.add_edges(src, dst)
    return new_g

def output_adj(graph):
    """Return the dense float64 adjacency matrix of a graph.

    ``A[u, v] = 1`` for every edge ``u -> v`` reported by ``graph.all_edges()``.
    Parallel edges still produce 1.

    Args:
        graph: object exposing ``number_of_nodes()`` and ``all_edges()``,
            the latter returning ``(src, dst)`` integer tensors.

    Returns:
        ``numpy.ndarray`` of shape ``(N, N)``.
    """
    n_nodes = graph.number_of_nodes()
    A = np.zeros([n_nodes, n_nodes])
    src, dst = graph.all_edges()
    # Vectorised fancy-index assignment replaces the per-edge Python loop.
    A[src.numpy(), dst.numpy()] = 1
    return A

def compute_term(l, r):
    """Return the L2 distance between the sorted spectra of two square matrices.

    The eigenvalues of each matrix (from ``np.linalg.eig``) are sorted
    ascending. As with ``numpy.sort``, complex values order by real part,
    then imaginary part. When the matrices differ in size, the shorter
    spectrum is left-padded with zeros so both vectors align at their
    largest eigenvalues.

    Args:
        l: square 2-D array.
        r: square 2-D array, possibly a different size from ``l``.

    Returns:
        ``||sorted_eig(l) - sorted_eig(r)||_2`` as a NumPy float.
    """
    w_l = np.sort(np.linalg.eig(l)[0])
    w_r = np.sort(np.linalg.eig(r)[0])
    size = max(w_l.shape[0], w_r.shape[0])
    # np.pad keeps each spectrum's dtype. The old float64 padding buffer
    # silently dropped the imaginary part (ComplexWarning) of whichever
    # spectrum happened to be the shorter one.
    w_l = np.pad(w_l, (size - w_l.shape[0], 0), mode='constant')
    w_r = np.pad(w_r, (size - w_r.shape[0], 0), mode='constant')
    return np.linalg.norm(w_l - w_r)

# Compare the ego-network Laplacian spectra of two graphs and print the mean distance.
def main(args):
    """Print the mean spectral distance between ego-networks of two graphs.

    Reads the edge lists at ``args.file_path`` and ``args.label_path``. For
    every node it builds the symmetric normalized Laplacian of its
    ``args.n_layers + 1``-hop ego-network, sampled with DGL's
    ``NeighborSampler``. It then prints the mean of ``compute_term(l, r)``
    over all pairs, with ``l`` from the first graph and ``r`` from the
    second. Finally it opens an IPython shell (``embed()``) so the results
    can be inspected interactively.

    Args:
        args: parsed CLI namespace; reads ``file_path``, ``label_path`` and
            ``n_layers``.
    """

    def constructSubG(file_path):
        """Return (normalized ego-network Laplacians, max recorded degree) for one graph file."""
        g = read_struct_net(file_path)
        # Materialise the self-loop list first so the graph is not mutated
        # while a generator is still walking its adjacency.
        g.remove_edges_from(list(nx.selfloop_edges(g)))
        g = constructDGL(g)
        g.readonly()

        # batch size 1 with expand_factor -1 -- presumably "all in-neighbours",
        # i.e. one full ego-network per seed node; verify against the DGL version.
        node_sampler = dgl.contrib.sampling.NeighborSampler(
            g, 1, -1,
            neighbor_type='in', num_workers=1,
            add_self_loop=False,
            num_hops=args.n_layers + 1, shuffle=True)
        L_list = []
        degree = []
        for ego_g in tqdm(node_sampler):
            # Local indices are assigned walking the NodeFlow layers in reverse;
            # a node keeps the index of its first occurrence. Presumably this
            # gives the seed node index 0 (its degree is recorded below) -- TODO confirm.
            idx_coding = dict()
            for layer_id in range(args.n_layers + 2)[::-1]:
                for node_id in ego_g.layer_parent_nid(layer_id).numpy().tolist():
                    if node_id not in idx_coding:
                        idx_coding[node_id] = len(idx_coding)
            A = np.zeros([len(idx_coding), len(idx_coding)])

            for i in range(ego_g.num_blocks):
                u, v = g.find_edges(ego_g.block_parent_eid(i))
                for left_id, right_id in zip(u.numpy().tolist(), v.numpy().tolist()):
                    A[idx_coding[left_id], idx_coding[right_id]] = 1
            # lower part is the in-degree direction
            A = A.T
            # Symmetrise: overwrite the strict lower triangle with the mirrored upper one.
            i_lower = np.tril_indices(A.shape[0], -1)
            A[i_lower] = A.T[i_lower]
            deg = A.sum(0)
            degree.append(deg[0])
            D = np.diag(deg)
            L = D - A
            # D^{-1/2} with 0 for isolated nodes (e.g. a node whose only edge was a
            # removed self-loop). Plain 1/sqrt(0) gave inf, which turned the whole
            # normalized Laplacian into NaN.
            inv_sqrt_deg = np.zeros_like(deg)
            nonzero = deg > 0
            inv_sqrt_deg[nonzero] = 1.0 / np.sqrt(deg[nonzero])
            D_ = np.diag(inv_sqrt_deg)
            normailized_L = np.matmul(np.matmul(D_, L), D_)
            if np.isnan(normailized_L.sum()):
                # Debug hook: open an IPython shell on an unexpected NaN.
                embed()
            L_list.append(normailized_L)
        return L_list, max(degree)

    print(args.file_path, args.label_path)
    L_list, max_degree_l = constructSubG(args.file_path)
    R_list, max_degree_r = constructSubG(args.label_path)

    # Average of the pairwise spectral distances over every (l, r) pair.
    bound = 0
    for l in tqdm(L_list):
        for r in R_list:
            bound += compute_term(l, r)
    print(bound / (len(L_list) * len(R_list)))
    # Open an IPython shell with main's locals for interactive inspection.
    embed()
        # embed()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DGI')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.0,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--dgi-lr", type=float, default=1e-2,
                        help="dgi learning rate")
    parser.add_argument("--classifier-lr", type=float, default=1e-2,
                        help="classifier learning rate")
    parser.add_argument("--n-dgi-epochs", type=int, default=300,
                        help="number of training epochs")
    parser.add_argument("--n-classifier-epochs", type=int, default=100,
                        help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=32,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers (ego-networks use n_layers + 1 hops)")
    parser.add_argument("--weight-decay", type=float, default=0.,
                        help="Weight for L2 loss")
    parser.add_argument("--patience", type=int, default=20,
                        help="early stop patience condition")
    parser.add_argument("--model", action='store_true',
                        help="model flag (default=False)")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--model-type", type=int, default=2,
                        help="model type id (default=2)")
    parser.add_argument("--graph-type", type=str, default="DD",
                        help="graph type (default=DD)")
    parser.add_argument("--data-id", type=str,
                        help="[usa, europe, brazil]")
    parser.add_argument("--data-src", type=str, default='',
                        help="[usa, europe, brazil]")
    # main() opens both paths unconditionally, so fail early with a clear
    # argparse error instead of a TypeError from open(None).
    parser.add_argument("--file-path", type=str, required=True,
                        help="edge-list file of the first graph")
    parser.add_argument("--label-path", type=str, required=True,
                        help="edge-list file of the second graph (compared against --file-path)")
    parser.add_argument("--model-id", type=int, default=0,
                        help="[0, 1, 2, 3]")

    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    main(args)