#!/usr/bin/env python
# coding: utf-8

from __future__ import division
from __future__ import print_function

import argparse
import time
import math
import numpy as np
from subprocess import check_output

import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim

import collections
import re
import pandas as pd

import community
from sklearn.cluster import KMeans

import community
import matplotlib.pyplot as plt

def get_assignment(G, model, num_classes):
    """Assign every node of ``G`` to the community its incident edges favour.

    The model is queried once for the community distribution of every edge.
    Each edge's distribution is added to the score rows of both of its
    endpoints, and every node is labelled with the arg-max of its row.

    Args:
        G: Graph exposing ``number_of_nodes()`` and ``edges()``. Node ids are
            used directly as row indices, so they are assumed to be the
            integers ``0 .. n_nodes - 1``.
        model: Object with ``eval()`` that, when called as
            ``model(w, c, None, 1., True)``, returns one row of
            ``num_classes`` scores per edge.
        num_classes: Number of communities (columns of the score matrix).

    Returns:
        dict mapping node index -> community index (plain ``int``). Nodes
        without any edge keep an all-zero row and therefore get community 0.
    """
    n_nodes = G.number_of_nodes()
    scores = np.zeros((n_nodes, num_classes))
    model.eval()

    edges = [(u, v) for u, v in G.edges()]
    if edges:
        batch = torch.LongTensor(edges)
        # Pure inference: no autograd graph is needed for the edge scores.
        with torch.no_grad():
            q = model(batch[:, 0], batch[:, 1], None, 1., True)
        # One device->host transfer instead of two per edge.
        q = q.detach().cpu().numpy()
        # Flattening the (u, v) pairs and repeating each edge row twice keeps
        # the accumulation order u0, v0, u1, v1, ...; ``np.add.at`` is
        # unbuffered, so repeated node indices are summed correctly.
        np.add.at(scores, batch.numpy().reshape(-1), np.repeat(q, 2, axis=0))

    labels = scores.argmax(axis=-1)
    return {node: int(labels[node]) for node in range(n_nodes)}

def classical_modularity_calculator(graph, embedding, model='gcn_vae', cluster_number=5):
    """
    Score a node partition of ``graph`` with ``community.modularity``.

    When ``model`` is ``'gcn_vae'`` the ``embedding`` argument is passed
    through unchanged and used as the partition itself. For any other value
    the rows of ``embedding`` are clustered with k-means into
    ``cluster_number`` groups (fixed ``random_state=0``, a single
    initialisation) and the ``row index -> cluster label`` mapping is scored.

    Returns the value produced by ``community.modularity``.
    """
    if model != 'gcn_vae':
        clustering = KMeans(n_clusters=cluster_number, random_state=0, n_init=1)
        clustering.fit(embedding)
        partition = {}
        for row in range(embedding.shape[0]):
            partition[row] = int(clustering.labels_[row])
    else:
        partition = embedding

    return community.modularity(partition, graph)

class GCNModelGumbel(nn.Module):
    """Edge-level community model with a Gumbel-softmax categorical latent.

    For an edge ``(w, c)`` the element-wise product of the two node
    embeddings is projected to ``categorical_dim`` community logits. A
    one-hot community sample is mapped back through the community weight
    matrix and the context-node weight matrix to produce one score per node.

    Constructor arguments:
        size: number of nodes (rows of the node embedding table and number
            of output scores).
        embedding_dim: width of the node embeddings.
        categorical_dim: number of communities.
        dropout: accepted but not used anywhere in this class.
        device: device the two linear layers are moved to and on which the
            forward computation runs.
    """

    def __init__(self, size, embedding_dim, categorical_dim, dropout, device):
        super(GCNModelGumbel, self).__init__()
        self.size = size
        self.embedding_dim = embedding_dim
        self.categorical_dim = categorical_dim
        self.device = device

        # Only the two linear layers are moved to ``device`` here; rows looked
        # up from ``node_embeddings`` are moved inside ``forward``.
        self.community_embeddings = nn.Linear(
            embedding_dim, categorical_dim, bias=False).to(device)
        self.node_embeddings = nn.Embedding(size, embedding_dim)
        self.contextnode_embeddings = nn.Linear(
            size, embedding_dim, bias=False).to(device)

        self.init_emb()

        self.include_w = False

    def init_emb(self):
        """Xavier-initialise every ``nn.Linear`` weight and zero any bias."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            torch.nn.init.xavier_uniform_(layer.weight.data)
            if layer.bias is None:
                continue
            layer.bias.data.fill_(0.0)

    def forward(self, w, c, neg, temp, getq=False):
        """Run the model on a batch of edges.

        Args:
            w: index tensor with the first endpoint of every edge.
            c: index tensor with the second endpoint of every edge.
            neg: not used by this method.
            temp: Gumbel-softmax temperature (only used in training mode).
            getq: if true, return only the per-edge community distribution.

        Returns:
            ``softmax(q)`` when ``getq`` is true, otherwise the tuple
            ``(recon, softmax(q), prior)`` where ``recon`` holds one score
            per node for every edge and ``prior`` is the community
            distribution computed from the first endpoint alone.
        """
        src_emb = self.node_embeddings(w).to(self.device)
        dst_emb = self.node_embeddings(c).to(self.device)

        # logits.shape: [batch_size, categorical_dim]
        logits = self.community_embeddings(src_emb * dst_emb)
        posterior = F.softmax(logits, dim=-1)
        if getq:
            return posterior

        if self.training:
            # Straight-through one-hot sample.
            z = F.gumbel_softmax(logits=logits, tau=temp, hard=True)
        else:
            # Deterministic one-hot of the most likely community.
            hard_idx = logits.argmax(dim=-1).reshape(logits.shape[0], 1)
            z = torch.zeros(logits.shape, device=self.device).scatter_(1, hard_idx, 1.)

        prior = F.softmax(self.community_embeddings(src_emb), dim=-1)

        # z.shape: [batch_size, categorical_dim]
        community_vec = torch.mm(z, self.community_embeddings.weight)
        recon = torch.mm(community_vec, self.contextnode_embeddings.weight)
        return recon, posterior, prior
            

# Run configuration. The options are declared with argparse, but the parser is
# fed an empty argument list, so the real command line is ignored and every
# option takes its default unless it is overridden explicitly below.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='s', help="models used")
parser.add_argument('--lamda', type=float, default=100, help="")
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=2000000000000000000000, help='Number of epochs to train.')
parser.add_argument('--embedding-dim', type=int, default=128, help='')
parser.add_argument('--lr', type=float, default=0.05, help='Initial learning rate.')
parser.add_argument('--dropout', type=float, default=0., help='Dropout rate (1 - keep probability).')
parser.add_argument('--dataset-str', type=str, default='facebook0', help='type of dataset.')
# parser.add_argument('--task', type=str, default='community', help='type of dataset.')

# An explicit empty list replaces the original ``parse_args('')``; argparse
# turned that empty string into the same empty list.
args = parser.parse_args([])
args.dataset_str = 'facebook0'
args.lr = 0.05
args.lamda = 0  # overrides the default of 100

# Apply the seed option, which was previously parsed but never used, so that
# changing it actually changes the random streams.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

embedding_dim = args.embedding_dim
cur_lr = lr = args.lr
epochs = args.epochs
# Gumbel-softmax temperature: start value, lower bound and annealing rate.
temp = 1.
temp_min = 0.1
ANNEAL_RATE = 0.00003

# Load the dataset, build the model and optimizer, and precompute the
# per-edge overlap weights used by the optional smoothing loss.

# Extend the import path with the parent directory before importing the
# project-local helper (presumably where data_utils lives; verify).
import sys
sys.path.append('../')
from data_utils import load_dataset
# NOTE(review): relabel=True presumably renumbers nodes to 0..n-1, which the
# rest of the script relies on when using node ids as indices; confirm in
# data_utils.
G, adj, gt_communities = load_dataset(args.dataset_str, relabel=True)

n_nodes = G.number_of_nodes()
print('n_nodes', n_nodes)
print('n_edges', G.number_of_edges())
# The ground-truth community count is computed and then immediately replaced
# by a fixed value of 50.
categorical_dim = len(gt_communities)
categorical_dim = 50
print('categorical_dim', categorical_dim)
# Drop the reference to the ground-truth communities; they are not used again.
gt_communities = None

# Copy of the adjacency matrix with its diagonal (self loops) subtracted out
# and the resulting explicit zeros removed.
adj_orig = adj
adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
adj_orig.eliminate_zeros()

# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = GCNModelGumbel(adj.shape[0], embedding_dim, categorical_dim, args.dropout, device)

optimizer = optim.Adam(model.parameters(), lr=lr)

# Bookkeeping containers; nothing in the visible script fills or reads them.
hidden_emb = None
history_valap = []
history_mod = []

# Every edge of the graph is a training example.
train_edges = [(u,v) for u,v in G.edges()]
n_nodes = G.number_of_nodes()
print('len(train_edges)', len(train_edges))
print('calculating normalized_overlap')
from score_utils import normalized_overlap
# One overlap score per training edge, in the same order as train_edges.
overlap = torch.Tensor([normalized_overlap(G,u,v) for u,v in train_edges]).to(device)

def loss_function(recon_c, q_y, prior, c, norm=None, pos_weight=None):
    """Reconstruction cross-entropy plus KL(q_y || prior), averaged per edge.

    Args:
        recon_c: ``[batch, n_nodes]`` scores for the second endpoint of
            every edge.
        q_y: ``[batch, categorical_dim]`` community distribution per edge.
        prior: ``[batch, categorical_dim]`` prior community distribution.
        c: ``[batch]`` target node indices.
        norm: unused, kept for call compatibility.
        pos_weight: unused, kept for call compatibility.

    Returns:
        Scalar tensor: mean cross-entropy of ``recon_c`` against ``c`` plus
        the batch mean of the KL divergence between ``q_y`` and ``prior``.
    """
    # The same epsilon guards both logarithms. Without it on ``prior`` a zero
    # probability yields ``0 * (-inf)`` = NaN inside the KL term.
    eps = 1e-20

    recon_loss = F.cross_entropy(recon_c, c, reduction='sum') / c.shape[0]

    log_qy = torch.log(q_y + eps)
    log_prior = torch.log(prior + eps)
    kld = torch.sum(q_y * (log_qy - log_prior), dim=-1).mean()

    return recon_loss + kld

# Full-batch training loop: every epoch runs the model on all training edges,
# takes one optimizer step, anneals the Gumbel-softmax temperature and reports
# the modularity of the current hard assignment.

# The edge list never changes, so the index tensors are built once instead of
# once per epoch.
batch = torch.LongTensor(train_edges)
w = batch[:, 0]
c = batch[:, 1]
negsamples = None
target = c.to(device)

if args.lamda > 0:
    # Endpoint indices and node degrees on the training device for the
    # smoothing term. Node ids are used as row indices, as elsewhere in this
    # script.
    w_dev = w.to(device)
    c_dev = target
    node_degree = torch.tensor([G.degree(node) for node in range(n_nodes)],
                               dtype=torch.float32, device=device)

for epoch in range(args.epochs):
    print("Epoch {}".format(epoch))

    model.train()
    optimizer.zero_grad()

    L, q, prior = model(w, c, negsamples, temp)

    loss = loss_function(L, q, prior, target, None, None)
    if np.isnan(loss.item()):
        print('training stopped because NAN')
        # Waits for the user before leaving the loop.
        input()
        break

    if args.lamda > 0:
        # Per-node community profile: degree-normalised sum of the
        # distributions of all incident edges. Out-of-place ``index_add``
        # keeps the result differentiable w.r.t. ``q``; the previous in-place
        # ``+=`` into a ``requires_grad=True`` tensor fails on CPU, where
        # that tensor is a leaf.
        res = torch.zeros([n_nodes, categorical_dim], dtype=torch.float32, device=device)
        res = res.index_add(0, w_dev, q / node_degree[w_dev].unsqueeze(-1))
        res = res.index_add(0, c_dev, q / node_degree[c_dev].unsqueeze(-1))

        # Penalise profile differences across edges, weighted by the
        # precomputed neighbourhood overlap of the edge.
        tmp = ((res[w_dev] - res[c_dev])**2).mean(dim=-1)
        assert overlap.shape == tmp.shape
        smoothing_loss = (overlap*tmp).mean()
        loss = loss + args.lamda * smoothing_loss

    loss.backward()
    cur_loss = loss.item()
    optimizer.step()

    # NOTE(review): the decay factor is applied to the already-decayed
    # temperature, so the schedule compounds over epochs; confirm this is
    # intended.
    temp = np.maximum(temp*np.exp(-ANNEAL_RATE*epoch), temp_min)

    model.eval()
    assert not model.training

    assignment = get_assignment(G, model, categorical_dim)
    modularity = classical_modularity_calculator(G, assignment)

    # One report line per epoch; the smoothing term is shown only when active.
    fields = ["Epoch:", '%04d' % (epoch + 1),
              "temp:", '{:.5f}'.format(temp),
              "train_loss=", "{:.5f}".format(cur_loss)]
    if args.lamda > 0:
        fields += ["smoothing_loss=", "{:.5f}".format(args.lamda * smoothing_loss.item())]
    fields += ["modularity=", "{:.5f}".format(modularity)]
    print(*fields)

    # Multiply the learning rate by 0.99 every 100 epochs (including epoch 0).
    if epoch % 100 == 0:
        cur_lr *= .99
        for param_group in optimizer.param_groups:
            param_group['lr'] = cur_lr

print("Optimization Finished!")
