from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from dataloader import *

from graph_util import num_atom_features, num_bond_features
from collections import OrderedDict
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

import sys
sys.path.insert(0, '../src')
from util import output_classification_result


def tensor_to_variable(x):
    """Return ``x`` as a float ``Variable``, moved to the GPU when CUDA is available.

    :param x: a torch tensor of any dtype.
    :return: ``Variable`` wrapping the float-cast tensor.
    """
    # Transfer first, then cast, mirroring the order used throughout this file.
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed.float())


def variable_to_numpy(x):
    """Convert a tensor/Variable into a numpy array.

    When CUDA is available the value is first copied to host memory; the
    underlying ``.data`` tensor is then exposed as a numpy array.

    :param x: tensor or Variable to convert.
    :return: numpy array holding the values of ``x``.
    """
    host_side = x.cpu() if torch.cuda.is_available() else x
    return host_side.data.numpy()


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return a view of ``x`` with shape ``(x.size(0), -1)``."""
        leading = x.size(0)
        return x.view(leading, -1)


class MappingLayer(nn.Module):
    """Learnable mapping that re-weights the graph matrices of a molecule.

    The layer owns one weight matrix ``P`` of shape
    ``(max_atom_num, latent_dim)``.  ``P @ P^T`` (shape
    ``(max_atom_num, max_atom_num)``) is multiplied element-wise into the
    adjacency and distance matrices, and the node attribute matrix is
    projected through ``P`` with a matrix product.
    """

    def __init__(self, max_atom_num, latent_dim):
        """
        :param max_atom_num: number of rows of the weight matrix.
        :param latent_dim: number of columns of the weight matrix.
        """
        super(MappingLayer, self).__init__()
        weight = torch.randn(max_atom_num, latent_dim)
        # Only move to the GPU when one exists, like the rest of this file;
        # an unconditional .cuda() makes the layer unusable on CPU-only hosts.
        if torch.cuda.is_available():
            weight = weight.cuda()
        self.parameter = nn.Parameter(weight, requires_grad=True)
        # The same tensor is also registered as 'mapping_layer'; kept so that
        # state_dict() continues to expose both keys.
        self.register_parameter('mapping_layer', self.parameter)

    def forward(self, node_attr_matrix, adjacent_matrix, distance_matrix):
        """Apply the mapping to all three inputs.

        :param node_attr_matrix: tensor whose last dimension equals
            ``max_atom_num`` (required by the matmul with the weight).
        :param adjacent_matrix: tensor broadcastable against
            ``(max_atom_num, max_atom_num)``.
        :param distance_matrix: tensor broadcastable against
            ``(max_atom_num, max_atom_num)``.
        :return: tuple ``(node_attr_out, adjacent_out, distance_out)``.
        """
        mapping_layer = torch.mm(self.parameter, self.parameter.transpose(0, 1))
        # logging uses %-style lazy formatting: the value needs a placeholder,
        # otherwise emitting the record fails once DEBUG is enabled.
        logging.debug('mapping_layer\t%s', mapping_layer.size())

        adjacent_out = torch.mul(adjacent_matrix, mapping_layer)
        logging.debug('adjacent out\t%s', adjacent_out.size())

        distance_out = torch.mul(distance_matrix, mapping_layer)
        logging.debug('distance out\t%s', distance_out.size())

        node_attr_out = torch.matmul(node_attr_matrix, self.parameter)
        logging.debug('node attr out\t%s', node_attr_out.size())
        return node_attr_out, adjacent_out, distance_out


class GraphModel(nn.Module):
    """Graph classifier: message passing + per-atom dense layer + MLP head.

    The head ends in ``Linear(500, 1)`` followed by ``Sigmoid``, so the model
    emits one value in (0, 1) per sample.
    """

    def __init__(self, max_atom_num, atom_attr_dim, bond_attr_dim, latent_dim):
        """
        :param max_atom_num: atom count used to size the first head layer
            (``max_atom_num * 10`` input features after flattening).
        :param atom_attr_dim: input width of the per-atom dense layer.
        :param bond_attr_dim: stored on the instance; not used by the layers
            built here.
        :param latent_dim: stored on the instance; not used by the layers
            built here.
        """
        super(GraphModel, self).__init__()
        self.max_atom_num = max_atom_num
        self.atom_attr_dim = atom_attr_dim
        self.bond_attr_dim = bond_attr_dim
        self.latent_dim = latent_dim

        # NOTE(review): Message_Passing is not defined in this file;
        # presumably it arrives through `from dataloader import *` -- confirm.
        self.graph_modules = nn.Sequential(OrderedDict([
            ('message_passing_0', Message_Passing()),
            ('dense_0', nn.Linear(self.atom_attr_dim, 10)),
            ('activation_0', nn.ReLU()),
        ]))

        self.fully_connected = nn.Sequential(
            Flatten(),
            nn.Linear(self.max_atom_num*10, 1024),
            nn.ReLU(),
            nn.Linear(1024, 500),
            nn.ReLU(),
            nn.Linear(500, 1),
            nn.Sigmoid(),
        )

    def forward(self, node_attr_matrix, adjacent_matrix, distance_matrix):
        """Run the graph modules and the dense head.

        :param node_attr_matrix: node features; cast to float here.
            Assumed to be (batch, max_atom_num, atom_attr_dim) given the
            layer sizes above -- TODO confirm against Message_Passing.
        :param adjacent_matrix: adjacency input; cast to float here.
        :param distance_matrix: distance input, forwarded unchanged to the
            message-passing modules.
        :return: output of the sigmoid head, one value per sample.
        """
        node_attr_matrix = node_attr_matrix.float()
        adjacent_matrix = adjacent_matrix.float()
        x = node_attr_matrix
        # %-style placeholder is required; the original call passed the size
        # as a stray argument, which breaks the record once DEBUG is enabled.
        logging.debug('shape\t%s', x.size())

        # nn.Sequential cannot forward extra inputs, so the children are
        # walked manually and message-passing modules get the graph matrices.
        for (name, module) in self.graph_modules.named_children():
            if 'message_passing' in name:
                x = module(x, adjacent_matrix=adjacent_matrix, distance_matrix=distance_matrix)
            else:
                x = module(x)

        x = self.fully_connected(x)

        return x


def visualize(model):
    """Print a summary of ``model``'s weights.

    First prints every ``state_dict`` entry (sorted by key) with its shape,
    then every named parameter with its ``requires_grad`` flag and values.

    :param model: an ``nn.Module``.
    :return: None.
    """
    state_entries = sorted(model.state_dict().items())
    for key, tensor in state_entries:
        print(key, tensor.shape)

    for param_name, parameter in model.named_parameters():
        print(param_name, '\t', parameter.requires_grad, '\t', parameter.data)


def train(data_loader):
    """Run one optimisation pass over ``data_loader``.

    Uses the module-level ``graph_model`` and ``optimizer``.  Every batch is
    expected to unpack as
    ``(adjacent_matrix, distance_matrix, node_attr_matrix, y_label)``.

    :param data_loader: iterable of batches that also supports ``len()``.
    :return: binary cross-entropy loss averaged over the batches.
    """
    graph_model.train()
    # Built once per call instead of once per batch.  The default reduction
    # is the mean, i.e. the same as the deprecated ``size_average=True``.
    # NOTE: per-sample re-weighting (sample_weight = 1 + y_label * 999) was
    # tried here previously and is currently disabled.
    criterion = nn.BCELoss()
    total_loss = 0.0
    for adjacent_matrix, distance_matrix, node_attr_matrix, y_label in data_loader:
        adjacent_matrix = tensor_to_variable(adjacent_matrix)
        distance_matrix = tensor_to_variable(distance_matrix)
        node_attr_matrix = tensor_to_variable(node_attr_matrix)
        y_label = tensor_to_variable(y_label)
        y_predicted = graph_model(adjacent_matrix=adjacent_matrix,
                                  distance_matrix=distance_matrix,
                                  node_attr_matrix=node_attr_matrix)
        loss = criterion(y_predicted, y_label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The loss is a 0-dim tensor; ``loss.data[0]`` raises IndexError on
        # current PyTorch, ``item()`` is the supported scalar accessor.
        total_loss += loss.item()
    return total_loss / len(data_loader)


def make_predictions(data_loader):
    """Collect labels and model outputs for every batch of ``data_loader``.

    Uses the module-level ``graph_model``.  This function does not switch the
    model to evaluation mode itself; ``test`` does so before calling it.

    :param data_loader: iterable of
        ``(adjacent_matrix, distance_matrix, node_attr_matrix, y_label)``
        batches, or ``None``.
    :return: ``(y_label_array, y_pred_array)`` as numpy arrays, or
        ``(None, None)`` when ``data_loader`` is ``None``.
    """
    if data_loader is None:
        return None, None
    y_label_list = []
    y_pred_list = []
    # Pure inference: no autograd graph is needed, so skip building one and
    # avoid holding activations for every batch.
    with torch.no_grad():
        for adjacent_matrix, distance_matrix, node_attr_matrix, y_label in data_loader:
            adjacent_matrix = tensor_to_variable(adjacent_matrix)
            distance_matrix = tensor_to_variable(distance_matrix)
            node_attr_matrix = tensor_to_variable(node_attr_matrix)
            y_label = tensor_to_variable(y_label)
            y_pred = graph_model(adjacent_matrix=adjacent_matrix,
                                 distance_matrix=distance_matrix,
                                 node_attr_matrix=node_attr_matrix)
            y_label_list.extend(variable_to_numpy(y_label))
            y_pred_list.extend(variable_to_numpy(y_pred))
    return np.array(y_label_list), np.array(y_pred_list)


def test(train_dataloader=None, val_dataloader=None, test_dataloader=None):
    """Evaluate the module-level ``graph_model`` and report classification results.

    The model is put in evaluation mode, predictions are collected for each
    supplied loader (a ``None`` loader yields ``None`` labels/predictions),
    and everything is handed to ``output_classification_result``.

    :param train_dataloader: loader for the training split, or ``None``.
    :param val_dataloader: loader for the validation split, or ``None``.
    :param test_dataloader: loader for the test split, or ``None``.
    :return: None.
    """
    graph_model.eval()

    # Same evaluation order as before: train, then validation, then test.
    train_result = make_predictions(train_dataloader)
    val_result = make_predictions(val_dataloader)
    test_result = make_predictions(test_dataloader)

    output_classification_result(
        y_train=train_result[0], y_pred_on_train=train_result[1],
        y_val=val_result[0], y_pred_on_val=val_result[1],
        y_test=test_result[0], y_pred_on_test=test_result[1],
        EF_ratio_list=[0.001, 0.0015, 0.01, 0.02]
    )


def save_model(weight_path):
    """Serialise the module-level ``graph_model`` to ``weight_path``.

    The whole model object is written with ``torch.save`` (not only its
    ``state_dict``).

    :param weight_path: destination file path; opened in binary write mode.
    """
    print('Saving weight path:\t', weight_path)
    with open(weight_path, 'wb') as out_handle:
        torch.save(graph_model, out_handle)


def load_best_model(weight_path):
    """Load and return the object stored at ``weight_path`` by ``torch.save``.

    :param weight_path: path of the checkpoint file.
    :return: the deserialised object (a full model when the file was written
        by ``save_model``).
    """
    # Without a GPU, remap CUDA storages to the CPU so that a checkpoint
    # written on a GPU machine can still be opened; with a GPU the default
    # placement is kept.
    map_location = None if torch.cuda.is_available() else 'cpu'
    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code -- only load checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which rejects fully pickled models; confirm the target version.
    with open(weight_path, 'rb') as f_:
        return torch.load(f_, map_location=map_location)


def val(data_loader):
    """Return the mean loss of ``graph_model`` on ``data_loader`` without training.

    Uses the module-level ``graph_model`` and ``criterion`` created by the
    script below.  The training loop called this function but the file never
    defined it.

    :param data_loader: iterable of
        ``(adjacent_matrix, distance_matrix, node_attr_matrix, y_label)``
        batches that also supports ``len()``.
    :return: loss averaged over the batches.
    """
    # train() switches back to training mode at the start of each epoch.
    graph_model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for adjacent_matrix, distance_matrix, node_attr_matrix, y_label in data_loader:
            adjacent_matrix = tensor_to_variable(adjacent_matrix)
            distance_matrix = tensor_to_variable(distance_matrix)
            node_attr_matrix = tensor_to_variable(node_attr_matrix)
            y_label = tensor_to_variable(y_label)
            y_predicted = graph_model(adjacent_matrix=adjacent_matrix,
                                      distance_matrix=distance_matrix,
                                      node_attr_matrix=node_attr_matrix)
            total_loss += criterion(y_predicted, y_label).item()
    return total_loss / len(data_loader)


if __name__ == '__main__':
    import time
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', action='store', dest='epoch',
                        type=int, required=False, default=20)
    parser.add_argument('--batch_size', action='store', dest='batch_size',
                        type=int, required=False, default=128)
    parser.add_argument('--seed', action='store', dest='seed',
                        type=int, required=False, default=123)
    given_args = parser.parse_args()

    # K numbered fold files followed by one extra file (index K).
    K = 5
    directory = '../datasets/keck_pria_lc/{}_graph.npz'
    file_list = [directory.format(i) for i in range(K)]
    file_list.append('../datasets/keck_pria_lc/keck_lc4_graph.npz')

    EPOCHS = given_args.epoch
    BATCH = given_args.batch_size
    MAX_ATOM_NUM = 55
    LATENT_DIM = 50
    ATOM_FEATURE_DIM = num_atom_features()
    BOND_FEATURE_DIM = num_bond_features()
    torch.manual_seed(given_args.seed)

    graph_model = GraphModel(max_atom_num=MAX_ATOM_NUM,
                             atom_attr_dim=ATOM_FEATURE_DIM,
                             bond_attr_dim=BOND_FEATURE_DIM,
                             latent_dim=LATENT_DIM)
    if torch.cuda.is_available():
        graph_model.cuda()
    print(graph_model)

    # Files 0-3 train, file 4 validates, the extra file is the test set.
    train_graph_matrix_file = file_list[:4]
    val_graph_matrix_file = file_list[4]
    test_graph_matrix_file = file_list[5]

    train_dataset = GraphDataSet_Distance_Adjacent(train_graph_matrix_file)
    val_dataset = GraphDataSet_Distance_Adjacent(val_graph_matrix_file)
    test_dataset = GraphDataSet_Distance_Adjacent(test_graph_matrix_file)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH, shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH, shuffle=False)

    optimizer = optim.Adam(graph_model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=5, min_lr=1e-4, verbose=True)
    # Mean-reduced BCE (the default); read by val() above.
    criterion = nn.BCELoss()

    for epoch in range(EPOCHS):
        print('Epoch: {}'.format(epoch))

        train_start_time = time.time()
        train_loss = train(train_dataloader)
        train_end_time = time.time()
        print('Train time: {:.3f}s. Train loss is {}.'.format(train_end_time - train_start_time, train_loss))

        val_start_time = time.time()
        val_loss = val(val_dataloader)
        # The scheduler lowers the learning rate when validation loss plateaus.
        scheduler.step(val_loss)
        val_end_time = time.time()
        print('Valid time: {:.3f}s. Val loss is {}.'.format(val_end_time - val_start_time, val_loss))
        print()

        # Full metric report every 10th epoch; the test split is held back
        # until training has finished.
        if epoch % 10 == 0:
            test_start_time = time.time()
            test(train_dataloader=train_dataloader, val_dataloader=val_dataloader, test_dataloader=None)
            test_end_time = time.time()
            print('Test time: {:.3f}s.'.format(test_end_time - test_start_time))
            print()

    test_start_time = time.time()
    test(train_dataloader=train_dataloader, val_dataloader=val_dataloader, test_dataloader=test_dataloader)
    test_end_time = time.time()
    print('Test time: {:.3f}s.'.format(test_end_time - test_start_time))
    print()
