from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from dataloader import *

from graph_util import num_atom_features, num_bond_features
from collections import OrderedDict
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

import sys
sys.path.insert(0, '../src')
from util import output_classification_result


def tensor_to_variable(x):
    """Cast ``x`` to float and wrap it in an autograd ``Variable``.

    When CUDA is available the tensor is moved to the GPU before the cast.
    """
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed.float())


def variable_to_numpy(x):
    """Return the data held by ``x`` as a NumPy array.

    When CUDA is available ``x`` is first copied to host memory, since
    ``.numpy()`` only works on CPU tensors.
    """
    on_host = x.cpu() if torch.cuda.is_available() else x
    return on_host.data.numpy()


class Flatten(nn.Module):
    """Module that collapses every dimension after the first into one."""

    def forward(self, x):
        # Keep the leading dimension and let view() infer the rest.
        leading = x.size(0)
        return x.view(leading, -1)


class MappingLayer(nn.Module):
    """Learnable per-atom embedding used to re-weight the graph matrices.

    The layer owns one parameter ``P`` of shape ``(max_atom_num, latent_dim)``.
    In ``forward`` the product ``P @ P.T`` (``max_atom_num x max_atom_num``)
    is multiplied element-wise into the adjacency and distance matrices, and
    the node attribute matrix is projected through ``P``.
    """

    def __init__(self, max_atom_num, latent_dim):
        """Create the ``(max_atom_num, latent_dim)`` parameter from N(0, 1)."""
        super(MappingLayer, self).__init__()
        initial = torch.randn(max_atom_num, latent_dim)
        # Only move to the GPU when one exists: an unconditional .cuda()
        # makes the layer impossible to construct on CPU-only hosts.
        if torch.cuda.is_available():
            initial = initial.cuda()
        self.parameter = nn.Parameter(initial, requires_grad=True)
        # Registers the same tensor a second time under 'mapping_layer';
        # retained so existing state-dict keys keep working.
        self.register_parameter('mapping_layer', self.parameter)

    def forward(self, node_attr_matrix, adjacent_matrix, distance_matrix):
        """Apply the mapping to the three graph inputs.

        :param node_attr_matrix: tensor whose last dimension is
            ``max_atom_num`` (it is matrix-multiplied with the parameter).
        :param adjacent_matrix: tensor broadcastable against
            ``(max_atom_num, max_atom_num)``.
        :param distance_matrix: tensor broadcastable against
            ``(max_atom_num, max_atom_num)``.
        :return: ``(node_attr_out, adjacent_out, distance_out)``.
        """
        mapping_layer = torch.mm(self.parameter, self.parameter.transpose(0, 1))
        # logging interpolates with ``msg % args``; each message needs a
        # placeholder for its argument or the record fails to format.
        logging.debug('mapping_layer\t%s', mapping_layer.size())

        adjacent_out = torch.mul(adjacent_matrix, mapping_layer)
        logging.debug('adjacent out\t%s', adjacent_out.size())

        distance_out = torch.mul(distance_matrix, mapping_layer)
        logging.debug('distance out\t%s', distance_out.size())

        node_attr_out = torch.matmul(node_attr_matrix, self.parameter)
        logging.debug('node attr out\t%s', node_attr_out.size())
        return node_attr_out, adjacent_out, distance_out


class GraphModel(nn.Module):
    """Binary classifier over (node attribute, adjacency, distance) matrices.

    The inputs pass through the mapping layers, are reduced to a single
    vector of length ``2 * max_atom_num + atom_attr_dim`` per sample, and are
    scored by a linear layer followed by a sigmoid.
    """

    def __init__(self, max_atom_num, atom_attr_dim, bond_attr_dim, latent_dim):
        """Build the layers and print the name and size of every parameter.

        :param max_atom_num: number of atom slots per molecule matrix.
        :param atom_attr_dim: number of attributes per atom.
        :param bond_attr_dim: stored on the instance; not used by the layers.
        :param latent_dim: stored on the instance; the mapping layer itself is
            always built with a latent dimension of 1, which is what the
            input size of ``fc_layer`` assumes.
        """
        super(GraphModel, self).__init__()
        self.max_atom_num = max_atom_num
        self.atom_attr_dim = atom_attr_dim
        self.bond_attr_dim = bond_attr_dim
        self.latent_dim = latent_dim

        self.mapping_layers = nn.ModuleList([
            MappingLayer(max_atom_num=max_atom_num, latent_dim=1),
        ])

        self.fc_layer = nn.Sequential(
            nn.Linear(self.max_atom_num*2 + self.atom_attr_dim, 1),
            nn.Sigmoid(),
        )

        for name, parameter in self.named_parameters():
            print(name, '\t', parameter.size())
        print()

    def forward(self, node_attr_matrix, adjacent_matrix, distance_matrix):
        """Score a batch of graphs.

        :param node_attr_matrix: expected ``(batch, max_atom_num, atom_attr_dim)``;
            it is transposed on dims 1 and 2 before the mapping layers.
        :param adjacent_matrix: expected ``(batch, max_atom_num, max_atom_num)``.
        :param distance_matrix: expected ``(batch, max_atom_num, max_atom_num)``.
        :return: sigmoid outputs of shape ``(batch, 1)``.
        """
        node_attr_matrix = node_attr_matrix.float()
        adjacent_matrix = adjacent_matrix.float()
        # Cast like the other two inputs so all three share one dtype.
        distance_matrix = distance_matrix.float()
        node_attr_matrix = node_attr_matrix.transpose(1, 2)
        # logging interpolates with ``msg % args``, so each message carries
        # a placeholder for the size it reports.
        logging.debug('node attr matrix\t%s', node_attr_matrix.size())

        for layer in self.mapping_layers:
            node_attr_matrix, adjacent_matrix, distance_matrix = layer(node_attr_matrix=node_attr_matrix,
                                                                       adjacent_matrix=adjacent_matrix,
                                                                       distance_matrix=distance_matrix)

        logging.debug('adjacent sum\t%s', adjacent_matrix.size())
        logging.debug('distance sum\t%s', distance_matrix.size())
        adjacent_sum = torch.sum(adjacent_matrix, dim=2, keepdim=True)
        distance_sum = torch.sum(distance_matrix, dim=2, keepdim=True)
        logging.debug('adjacent sum\t%s', adjacent_sum.size())
        logging.debug('distance sum\t%s', distance_sum.size())
        logging.debug('node attr sum\t%s', node_attr_matrix.size())

        x = torch.cat([adjacent_sum, distance_sum, node_attr_matrix], dim=1)
        # Drop only the trailing singleton dimension. A bare squeeze() would
        # also drop the batch dimension whenever a batch holds one sample.
        x = torch.squeeze(x, 2)
        logging.debug('x\t%s', x.size())

        x = self.fc_layer(x)
        return x

    def loss_(self, y_predicted, y_actual, alpha=1e-3, size_average=True):
        """Weighted binary cross-entropy plus an L1 penalty on the mapping parameters.

        :param y_predicted: predicted probabilities.
        :param y_actual: target labels, same shape as ``y_predicted``.
        :param alpha: multiplier applied to the L1 penalty.
        :param size_average: forwarded to ``nn.BCELoss``; ``True`` averages the
            per-element losses, ``False`` sums them.
        :return: ``bce + alpha * sum(|mapping parameters|)``.
        """
        # Samples with label 1 get weight 101, samples with label 0 weight 1.
        sample_weight = 1 + y_actual * (1.0 / 0.01)
        criterion = nn.BCELoss(weight=sample_weight, size_average=size_average)
        loss = criterion(y_predicted, y_actual)

        # Summed absolute value of every mapping parameter; this equals a
        # summed L1 loss against an all-zero target without allocating one.
        regularizer = 0
        for layer in self.mapping_layers:
            regularizer += layer.parameter.abs().sum()
        total_loss = loss + regularizer * alpha

        return total_loss


def visualize(model):
    """Print the shape of each state-dict entry, then every named parameter.

    State-dict entries are listed in sorted key order. For each named
    parameter the name, its ``requires_grad`` flag and its data are printed.
    """
    state = model.state_dict()
    for key in sorted(state):
        print(key, state[key].shape)
    for label, weights in model.named_parameters():
        print(label, '\t', weights.requires_grad, '\t', weights.data)


def train(data_loader):
    """Run one optimisation pass over ``data_loader``.

    Uses the module-level ``graph_model`` and ``optimizer``. Each batch loss
    is computed with ``size_average=False`` (summed over the batch).

    :param data_loader: iterable yielding
        ``(adjacent_matrix, distance_matrix, node_attr_matrix, y_label)`` batches.
    :return: the accumulated batch losses divided by ``len(data_loader)``.
    """
    graph_model.train()
    total_loss = 0
    for batch_id, (adjacent_matrix, distance_matrix, node_attr_matrix, y_label) in enumerate(data_loader):
        adjacent_matrix = tensor_to_variable(adjacent_matrix)
        distance_matrix = tensor_to_variable(distance_matrix)
        node_attr_matrix = tensor_to_variable(node_attr_matrix)
        y_label = tensor_to_variable(y_label)
        y_pred = graph_model(adjacent_matrix=adjacent_matrix, distance_matrix=distance_matrix, node_attr_matrix=node_attr_matrix)
        loss = graph_model.loss_(y_predicted=y_pred, y_actual=y_label, size_average=False)
        # A scalar loss is a 0-dim tensor since PyTorch 0.4, and indexing it
        # with [0] raises IndexError; .item() is the supported accessor.
        # Older releases without .item() still take the original path.
        total_loss += loss.item() if hasattr(loss, 'item') else loss.data[0]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    total_loss /= len(data_loader)
    return total_loss


def make_predictions(data_loader):
    """Collect labels and model outputs for every batch in ``data_loader``.

    Uses the module-level ``graph_model``.

    :param data_loader: iterable of
        ``(adjacent_matrix, distance_matrix, node_attr_matrix, y_label)``
        batches, or ``None``.
    :return: ``(labels, predictions)`` as NumPy arrays, or ``(None, None)``
        when ``data_loader`` is ``None``.
    """
    if data_loader is None:
        return None, None

    collected_labels, collected_preds = [], []
    for adjacent, distance, node_attr, target in data_loader:
        adjacent_var = tensor_to_variable(adjacent)
        distance_var = tensor_to_variable(distance)
        node_attr_var = tensor_to_variable(node_attr)
        target_var = tensor_to_variable(target)
        output = graph_model(adjacent_matrix=adjacent_var,
                             distance_matrix=distance_var,
                             node_attr_matrix=node_attr_var)
        # extend() appends one entry per sample in the batch.
        collected_labels.extend(variable_to_numpy(target_var))
        collected_preds.extend(variable_to_numpy(output))
    return np.array(collected_labels), np.array(collected_preds)


def test(train_dataloader=None, test_dataloader=None):
    """Evaluate the module-level ``graph_model`` and report the results.

    The model is switched to eval mode, predictions are collected for each
    loader that is not ``None``, and everything is handed to
    ``output_classification_result`` (no validation split is supplied).
    """
    graph_model.eval()
    train_labels, train_scores = make_predictions(train_dataloader)
    test_labels, test_scores = make_predictions(test_dataloader)
    ef_ratios = [0.001, 0.0015, 0.01, 0.02]
    output_classification_result(
        y_train=train_labels, y_pred_on_train=train_scores,
        y_val=None, y_pred_on_val=None,
        y_test=test_labels, y_pred_on_test=test_scores,
        EF_ratio_list=ef_ratios,
    )


def save_model(weight_path):
    """Serialise the module-level ``graph_model`` object to ``weight_path``."""
    print('Saving weight path:\t', weight_path)
    with open(weight_path, 'wb') as out_file:
        # The whole model object is written, not just its state dict.
        torch.save(graph_model, out_file)


def load_best_model(weight_path):
    """Load and return the object stored at ``weight_path`` via ``torch.load``.

    NOTE(review): ``torch.load`` unpickles the file, so only load files
    from a trusted source.
    """
    with open(weight_path, 'rb') as in_file:
        restored = torch.load(in_file)
    return restored


if __name__ == '__main__':
    import argparse
    import time

    # Command-line options as (flag, type, default); all are optional.
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in (
            ('--epoch', int, 20),
            ('--batch_size', int, 128),
            ('--learning_rate', float, 1e-3),
            ('--min_learning_rate', float, 1e-5),
            ('--seed', int, 123),
    ):
        parser.add_argument(flag, type=value_type, default=default_value)
    given_args = parser.parse_args()

    # K numbered graph files followed by the keck_lc4 file.
    K = 5
    directory = '../datasets/keck_pria_lc/{}_graph.npz'
    file_list = [directory.format(index) for index in range(K)]
    file_list.append('../datasets/keck_pria_lc/keck_lc4_graph.npz')

    EPOCHS = given_args.epoch
    BATCH = given_args.batch_size
    MAX_ATOM_NUM = 55
    LATENT_DIM = 55
    ATOM_FEATURE_DIM = num_atom_features()
    BOND_FEATURE_DIM = num_bond_features()

    # Seed before the model is built so its random initialisation is repeatable.
    torch.manual_seed(given_args.seed)

    graph_model = GraphModel(
        max_atom_num=MAX_ATOM_NUM,
        atom_attr_dim=ATOM_FEATURE_DIM,
        bond_attr_dim=BOND_FEATURE_DIM,
        latent_dim=LATENT_DIM,
    )
    if torch.cuda.is_available():
        graph_model.cuda()
    # graph_model.apply(weights_init)
    # visualize(graph_model)
    print(graph_model)

    # The first five files are used for training, the sixth for testing.
    train_graph_matrix_file, test_graph_matrix_file = file_list[:5], file_list[5]

    train_dataset = GraphDataSet_Distance_Adjacent(train_graph_matrix_file)
    test_dataset = GraphDataSet_Distance_Adjacent(test_graph_matrix_file)

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH, shuffle=False)

    optimizer = optim.Adam(graph_model.parameters(), lr=given_args.learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=0.2,
        patience=3,
        min_lr=given_args.min_learning_rate,
        verbose=True,
    )

    for epoch in range(EPOCHS):
        print('Epoch: {}'.format(epoch))

        # The reported train time includes the scheduler step.
        train_start_time = time.time()
        train_loss = train(train_dataloader)
        scheduler.step(train_loss)
        train_elapsed = time.time() - train_start_time
        print('Train time: {:.3f}s. Train loss is {}.'.format(train_elapsed, train_loss))

        # Every tenth epoch, evaluate on the training data only.
        if epoch % 10 != 0:
            continue
        test_start_time = time.time()
        test(train_dataloader=train_dataloader, test_dataloader=None)
        test_elapsed = time.time() - test_start_time
        print('Test time: {:.3f}s.'.format(test_elapsed))
        print()

    # Final evaluation on both the training and the test data.
    test_start_time = time.time()
    test(train_dataloader=train_dataloader, test_dataloader=test_dataloader)
    test_elapsed = time.time() - test_start_time
    print('Test time: {:.3f}s.'.format(test_elapsed))
    print()
