"""
From the paper:

We use the “mean” variant of GRAPHSAGE [16]
and apply a DIFFPOOL layer after every two GRAPHSAGE layers in our architecture. A total of 2
DIFFPOOL layers are used for the datasets. 
For small datasets such as ENZYMES and COLLAB, 1
DIFFPOOL layer can achieve similar performance. 
After each DIFFPOOL layer, 3 layers of graph
convolutions are performed, before the next DIFFPOOL layer, or the readout layer.
 The embedding matrix and the assignment matrix are computed by two separate 
 GRAPHSAGE models respectively.
In the 2 DIFFPOOL layer architecture, the number of clusters is set as 25% of the number
 of nodes before applying DIFFPOOL, 
 
 [OUR INITIAL CONFIGURATION WHILE OPERATING ON SMALL DATASETS]
 while in the 1 DIFFPOOL layer architecture, the number of clusters is set
as 10%. Batch normalization [18] is applied after every layer of GRAPHSAGE. 
We also found that adding an ℓ2 normalization to the node embeddings at each 
layer made the training more stable. 
In Section 4.2, we also test an analogous variant of DIFFPOOL on the STRUCTURE2VEC [7] 
architecture,
in order to demonstrate how DIFFPOOL can be applied on top of other GNN models. 
All models are trained for 3,000 epochs
with early stopping applied when the validation loss starts to drop.
"""
import os
import numpy as np
import torch
import dgl
import argparse
import time

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from dgl.data import tu

from diffpool.model.encoder import DiffPool
from diffpool.data_utils import pre_process

from data_handler import dataloader
from sklearn.model_selection import train_test_split 
import pickle
from tqdm import tqdm
#from sys import argv

#%%

def prepare_data(dataset, prog_args, train=False, pre_process=None):
    '''
    Optionally pre-process ``dataset`` and wrap it in a DGL ``GraphDataLoader``.

    When ``pre_process`` is truthy it is invoked as
    ``pre_process(dataset, prog_args)`` before the loader is created.
    Batches are shuffled only for training (``train`` truthy); the batch
    size and worker count come from ``prog_args.batch_size`` and
    ``prog_args.n_worker``.
    '''
    if pre_process:
        pre_process(dataset, prog_args)

    loader = dgl.dataloading.GraphDataLoader(
        dataset,
        batch_size=prog_args.batch_size,
        shuffle=bool(train),
        num_workers=prog_args.n_worker,
    )
    return loader


def _split_indices(labels, traintest_seed, trainval_seed):
    '''Return (idx_subtrain, idx_val, idx_test): a stratified 81/9/10 split of the dataset positions.'''
    # 10% of all graphs go to the test split.
    idx_train, idx_test, y_train, _ = train_test_split(
        np.arange(len(labels)), labels, test_size=0.1, stratify=labels,
        random_state=traintest_seed)
    # 10% of the remaining graphs go to validation. This split indexes into
    # idx_train, so map the positions back to dataset indices.
    idx_subtrain, idx_val, _, _ = train_test_split(
        np.arange(len(y_train)), y_train, test_size=0.1, stratify=y_train,
        random_state=trainval_seed)
    true_idx_subtrain = [idx_train[i] for i in idx_subtrain]
    true_idx_val = [idx_train[i] for i in idx_val]
    return true_idx_subtrain, true_idx_val, idx_test


def graph_classify_task(prog_args):
    '''
    Run one DiffPool graph-classification experiment end to end.

    The function:
    - loads the TU dataset;
    - splits it with the benchmark's stratified seeds;
    - trains with checkpoint selection on the validation split;
    - evaluates the selected checkpoint on the test split;
    - pickles ``[training_log, early_stopping_logger]`` into the experiment
      directory.

    Side effect: sets ``prog_args.save_dir`` to
    ``../results_diffpool/<dataset>/diffpool<num_pool>_gc<gc_per_block>_splitseed<seed>``.

    Returns the model. ``evaluate`` loads the selected checkpoint into it
    before testing, so it holds the best weights.
    '''
    # Encode the configurable architecture in the directory name so that runs
    # with different num_pool / gc_per_block do not overwrite each other.
    # The defaults (1, 3) reproduce the former hard-coded "diffpool1_gc3".
    experiment_name = 'diffpool%s_gc%s_splitseed%s' % (
        prog_args.num_pool, prog_args.gc_per_block,
        prog_args.split_traintest_seed)
    save_dir = os.path.join(os.path.abspath('../'), 'results_diffpool',
                            prog_args.dataset_name, experiment_name)
    prog_args.save_dir = save_dir

    # Per the original author, the graph order of dgl's LegacyTUDataset
    # matches our local loader, so indices computed from the local labels
    # address the same graphs.
    dataset = tu.LegacyTUDataset(
        name=str_datasetname_to_tudataset[prog_args.dataset_name])
    _, labels = dataloader.load_local_data(
        prog_args.data_path, prog_args.dataset_name, one_hot=prog_args.one_hot)
    idx_subtrain, idx_val, idx_test = _split_indices(
        labels, prog_args.split_traintest_seed, prog_args.split_trainval_seed)

    train_dataloader = prepare_data(
        torch.utils.data.Subset(dataset, idx_subtrain), prog_args,
        train=True, pre_process=pre_process)
    val_dataloader = prepare_data(
        torch.utils.data.Subset(dataset, idx_val), prog_args,
        train=False, pre_process=pre_process)
    test_dataloader = prepare_data(
        torch.utils.data.Subset(dataset, idx_test), prog_args,
        train=False, pre_process=pre_process)

    input_dim, label_dim, max_num_node = dataset.statistics()
    print("++++++++++STATISTICS ABOUT THE DATASET")
    print("dataset feature dimension is", input_dim)
    print("dataset label dimension is", label_dim)
    print("the max num node is", max_num_node)
    print("number of graphs is", len(dataset))

    hidden_dim = 64
    embedding_dim = 64

    # Assignment dimension: pool_ratio * the largest graph's node count.
    assign_dim = int(max_num_node * prog_args.pool_ratio)
    print("++++++++++MODEL STATISTICS++++++++")
    print("model hidden dim is", hidden_dim)
    print("model embedding dim for graph instance embedding", embedding_dim)
    print("initial batched pool graph dim is", assign_dim)
    activation = F.relu

    model = DiffPool(input_dim,
                     hidden_dim,
                     embedding_dim,
                     label_dim,
                     activation,
                     prog_args.gc_per_block,
                     prog_args.dropout,
                     prog_args.num_pool,
                     prog_args.linkpred,
                     prog_args.batch_size,
                     'meanpool',
                     assign_dim,
                     prog_args.pool_ratio)

    print("model init finished")
    print("MODEL:::::::", prog_args.method)
    if prog_args.cuda:
        model = model.cuda()

    early_stopping_logger, training_log = train(
        train_dataloader,
        model,
        prog_args,
        val_dataset=val_dataloader)
    # Passing the logger makes evaluate() reload the best saved checkpoint.
    result = evaluate(test_dataloader, model, prog_args, early_stopping_logger)
    print("test  accuracy {:.2f}%".format(result * 100))
    training_log['test_accuracy'] = result
    str_log = (prog_args.save_dir
               + '/model_splitvalseed%s_training_log.pkl' % prog_args.split_trainval_seed)
    # Use a context manager so the file handle is closed (it was leaked before).
    with open(str_log, 'wb') as fh:
        pickle.dump([training_log, early_stopping_logger], fh)
    return model

def train(dataset, model, prog_args, same_feat=True, val_dataset=None, verbose=False):
    '''
    Train ``model`` with Adam and checkpoint it on validation improvements.

    Parameters
    ----------
    dataset : iterable of (batch_graph, graph_labels)
        Training batches (a ``GraphDataLoader`` in this script).
    model : torch.nn.Module
        Must provide ``forward(batch_graph)`` returning class scores and
        ``loss(ypred, labels)``.
    prog_args : argparse.Namespace
        Reads ``save_dir``, ``lr`` (falls back to 0.001 if absent), ``cuda``,
        ``epoch``, ``clip`` and ``split_trainval_seed``.
    same_feat : bool
        Unused; kept for backward compatibility.
    val_dataset : iterable, optional
        Validation batches. When given, a checkpoint and the training log are
        written whenever validation accuracy is >= the best so far and does
        not exceed that epoch's training accuracy.
    verbose : bool
        Print per-epoch progress and record epoch times in the module-level
        ``global_train_time_per_epoch`` list.

    Returns
    -------
    (early_stopping_logger, training_log)
        ``early_stopping_logger`` holds the selected ``best_epoch`` and its
        ``val_acc`` (both -1 if nothing was selected). ``training_log`` holds
        per-epoch train/val accuracies plus the accuracies at the selected
        epochs.

    Raises
    ------
    ValueError
        If the training loader yields no graphs in an epoch.
    '''
    save_dir = prog_args.save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    seed = prog_args.split_trainval_seed
    log_path = save_dir + '/model_splitvalseed%s_training_log.pkl' % seed
    ckpt_path = (save_dir
                 + '/model_splitvalseed%s_best_val_accuracy_increasing_train_accuracy.pkl' % seed)

    dataloader = dataset
    # Honour the configured learning rate. It used to be hard-coded to 0.001,
    # which equals the script's default lr=1e-3.
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=getattr(prog_args, 'lr', 0.001))
    early_stopping_logger = {"best_epoch": -1, "val_acc": -1}
    training_log = {}
    for set_ in ['train', 'val']:
        training_log['%s_accuracy' % set_] = []
        training_log['validated_%s_accuracy' % set_] = []

    if prog_args.cuda > 0:
        torch.cuda.set_device(0)
    # Send batches to wherever the model lives. The model is only moved to
    # the GPU when prog_args.cuda is set, so keying off
    # torch.cuda.is_available() would mix devices on a GPU machine.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    for epoch in tqdm(range(prog_args.epoch)):
        begin_time = time.time()
        model.train()
        accum_correct = 0
        total = 0
        if verbose:
            print("\nEPOCH ###### {} ######".format(epoch))
        computation_time = 0.0
        for batch_graph, graph_labels in dataloader:
            for key, value in list(batch_graph.ndata.items()):
                batch_graph.ndata[key] = value.float()
            batch_graph = batch_graph.to(device)
            graph_labels = graph_labels.long().to(device)

            model.zero_grad()
            compute_start = time.time()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            accum_correct += torch.sum(indi == graph_labels).item()
            total += graph_labels.size(0)
            loss = model.loss(ypred, graph_labels)
            loss.backward()
            computation_time += time.time() - compute_start
            nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
            optimizer.step()

        if total == 0:
            raise ValueError("training dataloader yielded no graphs")
        train_accu = accum_correct / total
        if verbose:
            print("train accuracy for this epoch {} is {:.2f}%".format(
                epoch, train_accu * 100))
        training_log['train_accuracy'].append(train_accu)

        if verbose:
            elapsed_time = time.time() - begin_time
            print("loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
                loss.item(), elapsed_time, computation_time))
            global_train_time_per_epoch.append(elapsed_time)
        if val_dataset is not None:
            result = evaluate(val_dataset, model, prog_args)
            training_log['val_accuracy'].append(result)
            if verbose:
                print("validation  accuracy {:.2f}%".format(result * 100))
            # Checkpoint selection: validation accuracy must not decrease and
            # must not exceed this epoch's training accuracy.
            if result >= early_stopping_logger['val_acc'] and result <= train_accu:
                early_stopping_logger.update(best_epoch=epoch, val_acc=result)
                training_log['validated_val_accuracy'].append(result)
                # This list was initialised but never filled; record the
                # train accuracy at the selected epoch as well.
                training_log['validated_train_accuracy'].append(train_accu)
                with open(log_path, 'wb') as fh:
                    pickle.dump([training_log, early_stopping_logger], fh)
                torch.save(model.state_dict(), ckpt_path)
            if verbose:
                print("best epoch is EPOCH {}, val_acc is {:.2f}%".format(
                    early_stopping_logger['best_epoch'],
                    early_stopping_logger['val_acc'] * 100))
        torch.cuda.empty_cache()
    return early_stopping_logger, training_log


def evaluate(dataloader, model, prog_args, logger=None):
    '''
    Return the classification accuracy (in [0, 1]) of ``model`` on ``dataloader``.

    If ``logger`` is given and ``prog_args.save_dir`` is not None, the best
    checkpoint written by ``train`` for ``prog_args.split_trainval_seed`` is
    first loaded into ``model``. The model is left that way after the call.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no graphs.
    '''
    # Evaluate on the model's own device. Using torch.cuda.is_available()
    # would mix devices when the model was kept on the CPU.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    if logger is not None and prog_args.save_dir is not None:
        ckpt_path = (prog_args.save_dir
                     + '/model_splitvalseed%s_best_val_accuracy_increasing_train_accuracy.pkl'
                     % prog_args.split_trainval_seed)
        model.load_state_dict(torch.load(ckpt_path, map_location=device))

    model.eval()
    correct_label = 0
    total = 0
    with torch.no_grad():
        for batch_graph, graph_labels in dataloader:
            for key, value in list(batch_graph.ndata.items()):
                batch_graph.ndata[key] = value.float()
            batch_graph = batch_graph.to(device)
            graph_labels = graph_labels.long().to(device)
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct_label += torch.sum(indi == graph_labels).item()
            total += graph_labels.size(0)
    if total == 0:
        raise ValueError("evaluation dataloader yielded no graphs")
    # Divide by the graphs actually seen. The last batch can be smaller than
    # prog_args.batch_size, so len(dataloader) * batch_size over-counts.
    return correct_label / total
#%%

# Usage: python train_diffpool.py -ds imdb-b -seed 0
# With no flags the run uses dataset 'mutag' and split_trainval_seed 0.

# Filled by train() with per-epoch wall-clock times when verbose=True.
global_train_time_per_epoch = []
# Maps the short dataset names used by the local loader to DGL TU dataset names.
str_datasetname_to_tudataset = {'mutag':'MUTAG',
                                'ptc':'PTC_MR',
                                'enzymes':'ENZYMES',
                                'protein':'PROTEINS_full',
                                'nci1':'NCI1',
                                'imdb-b':'IMDB-BINARY',
                                'imdb-m':'IMDB-MULTI',
                                'collab':'COLLAB'}
abspath = os.path.abspath('../')
data_path = abspath+'/real_datasets/'

parser = argparse.ArgumentParser(description='DiffPool arguments')
# Optional so that running without flags behaves as before. The parser-level
# defaults from set_defaults() below take precedence when a flag is omitted.
parser.add_argument('-ds', '--dataset_name', type=str, dest='dataset_name',
                    choices=sorted(str_datasetname_to_tudataset),
                    help='Input dataset (default: mutag)')
parser.add_argument('-seed', '--split_trainval_seed', type=int,
                    dest='split_trainval_seed',
                    help='Seed of the train/validation split (default: 0)')

parser.set_defaults(
    dataset_name ='mutag',
    split_trainval_seed = 0,
    pool_ratio=0.15,
    num_pool=1,
    cuda=0,
    lr=1e-3,
    clip=2.0,
    batch_size=20,
    epoch=3000,
    split_traintest_seed = 0,
    n_worker=1,
    gc_per_block=3,
    dropout=0.0,
    method='diffpool',
    bn=True,
    bias=True,
    linkpred=True,
    data_path = data_path,
    one_hot = False,
    load_epoch=-1,
    data_mode='default')

# Only parse argv and train when executed as a script (or an interactive
# cell), not when the module is imported.
if __name__ == '__main__':
    prog_args = parser.parse_args()
    print(prog_args)
    model = graph_classify_task(prog_args)

    #print("Train time per epoch: {:.4f}".format( sum(global_train_time_per_epoch) / len(global_train_time_per_epoch) ))
    #print("Max memory usage: {:.4f}".format(torch.cuda.max_memory_allocated(0) / (1024 * 1024)))

"""
count_params = 0
params_set = [model]
for set_p in params_set:
    for p in set_p.parameters():
        if len(p.shape)==2:
            count_params += p.shape[0] * p.shape[1]
        elif len(p.shape)==1:
            count_params += p.shape[0]
print('count_params:', count_params)
"""