from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from dataloader import *

from graph_util import num_atom_features, num_bond_features
from collections import OrderedDict
import logging
# getLogger() with no name returns the ROOT logger; setting it to INFO affects
# every library that propagates to the root, and keeps DEBUG records (such as
# the logging.debug shape traces below) suppressed.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

import sys
# Make ``util`` importable from the sibling ``src`` directory. The path is
# resolved against the current working directory, not this file's location,
# so the script presumably must be launched from its own directory —
# NOTE(review): confirm.
sys.path.insert(0, '../src')
from util import output_classification_result


# Per-target "hit ratio" for the Tox21 targets, keyed by the names accepted by
# ``--target_name``. It is used to up-weight positive samples in
# GraphModel.loss_ and is passed as ``hit_ratio`` to
# output_classification_result. Presumably each value is the fraction of
# positive (active) labels for that target — NOTE(review): confirm against the
# dataset these numbers were computed from.
target_name_to_hit_ratio = {
    'NR-AR': 0.0415885461053,
    'NR-AR-LBD': 0.034836817015,
    'NR-AhR': 0.118862559242,
    'NR-Aromatase': 0.0510356609011,
    'NR-ER': 0.125826487678,
    'NR-ER-LBD': 0.0495367070563,
    'NR-PPAR-gamma': 0.0286263208453,
    'SR-ARE': 0.161624709118,
    'SR-ATAD5': 0.0367582706109,
    'SR-HSE': 0.0577032946106,
    'SR-MMP': 0.156383890317,
    'SR-p53': 0.0608950843727,
}


def tensor_to_variable(x):
    """Wrap ``x`` as a float32 autograd Variable, on the GPU when one exists.

    Args:
        x: a torch tensor of any dtype.

    Returns:
        ``Variable(x.float())``; the data lives on CUDA if
        ``torch.cuda.is_available()`` is true, otherwise on the CPU.
    """
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed.float())


def variable_to_numpy(x):
    """Return the data of a (possibly GPU-resident) Variable as a NumPy array.

    The value is copied to host memory first when CUDA is available, and
    ``.data`` detaches it from the autograd graph before conversion.
    """
    host_side = x.cpu() if torch.cuda.is_available() else x
    return host_side.data.numpy()


class Flatten(nn.Module):
    """Collapse every dimension after the leading (batch) one into a single axis."""

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)


class Message_Passing(nn.Module):
    """Parameter-free message-passing step: ``x + bmm(adjacent_matrix, x)``.

    ``torch.bmm`` requires both inputs to be 3-D with the same batch size, and
    ``adjacent_matrix.size(2) == x.size(1)``. Each node's output is its own
    feature vector plus the adjacency-weighted sum of the other nodes' features.
    """

    def forward(self, x, adjacent_matrix):
        neighbor_nodes = torch.bmm(adjacent_matrix, x)
        # logging formats ``msg % args``; the arguments need a %s placeholder,
        # otherwise a formatting error is reported whenever DEBUG is enabled.
        logging.debug('neighbor message\t%s', neighbor_nodes.size())
        x = x + neighbor_nodes
        logging.debug('x shape\t%s', x.size())
        return x


class GraphModel(nn.Module):
    """Two rounds of message passing followed by a dense binary classifier.

    Each round adds the adjacency-aggregated neighbour features to every node
    (Message_Passing), then applies a Linear layer and a sigmoid
    (atom_attr_dim -> 20 -> 10). The per-node features are flattened and fed
    to a fully connected stack ending in a single sigmoid output.

    Args:
        max_atom_num: number of node slots per graph. The flattened graph has
            ``max_atom_num * 10`` features, so the node dimension of the
            inputs must equal this value.
        atom_attr_dim: size of each node's input feature vector.
        bond_attr_dim: stored on the instance; not used by any layer here.
        latent_dim: stored on the instance; not used by any layer here.
    """

    def __init__(self, max_atom_num, atom_attr_dim, bond_attr_dim, latent_dim):
        super(GraphModel, self).__init__()
        self.max_atom_num = max_atom_num
        self.atom_attr_dim = atom_attr_dim
        self.bond_attr_dim = bond_attr_dim
        self.latent_dim = latent_dim

        # Kept as an ordered container so forward() can dispatch on the layer
        # name: message-passing layers also need the adjacency matrix.
        self.graph_modules = nn.Sequential(OrderedDict([
            ('message_passing_0', Message_Passing()),
            ('dense_0', nn.Linear(self.atom_attr_dim, 20)),
            ('activation_0', nn.Sigmoid()),
            ('message_passing_1', Message_Passing()),
            ('dense_1', nn.Linear(20, 10)),
            ('activation_1', nn.Sigmoid()),
        ]))

        # The 10 here must match the output width of 'dense_1' above.
        self.fully_connected = nn.Sequential(
            Flatten(),
            nn.Linear(self.max_atom_num*10, 1024),
            nn.ReLU(),
            nn.Linear(1024, 500),
            nn.Sigmoid(),
            nn.Linear(500, 1),
            nn.Sigmoid(),
        )

    def forward(self, node_attr_matrix, adjacent_matrix):
        """Return one sigmoid score per graph, shape (batch, 1).

        Expected inputs (torch.bmm enforces 3-D):
        node_attr_matrix is (batch, max_atom_num, atom_attr_dim) and
        adjacent_matrix is (batch, max_atom_num, max_atom_num).
        Both are cast to float32.
        """
        x = node_attr_matrix.float()
        adjacent_matrix = adjacent_matrix.float()
        # Lazy %-style argument: logging formats ``msg % args``.
        logging.debug('shape\t%s', x.size())

        for name, module in self.graph_modules.named_children():
            if 'message_passing' in name:
                x = module(x, adjacent_matrix=adjacent_matrix)
            else:
                x = module(x)

        return self.fully_connected(x)

    def loss_(self, y_predicted, y_actual, size_average=True, hit_ratio=None):
        """Class-weighted binary cross-entropy.

        Each element is weighted by ``1 + y_actual / hit_ratio``. For 0/1
        labels, positives therefore get weight ``1 + 1 / hit_ratio`` and
        negatives weight 1.

        Args:
            y_predicted: sigmoid outputs in [0, 1].
            y_actual: targets with the same shape as ``y_predicted``.
            size_average: legacy BCELoss flag. True gives the mean over
                elements; False gives the sum.
            hit_ratio: positive-class ratio used for weighting. Defaults to
                ``target_name_to_hit_ratio[target_name]``, where ``target_name``
                is the module-level global assigned by the script entry point.
                Pass it explicitly when using this class outside that script.
        """
        if hit_ratio is None:
            hit_ratio = target_name_to_hit_ratio[target_name]
        sample_weight = 1 + y_actual * (1.0 / hit_ratio)
        # size_average is deprecated in newer PyTorch (still accepted, with a
        # warning) and is kept for compatibility with the version this file targets.
        criterion = nn.BCELoss(weight=sample_weight, size_average=size_average)
        return criterion(y_predicted, y_actual)


def visualize(model):
    """Print a model's state-dict shapes, then every parameter's value.

    First block: one ``<key> <shape>`` line per state-dict entry, in sorted
    key order. Second block: for each named parameter, its name, its
    ``requires_grad`` flag and its data, separated by tabs.
    """
    state = model.state_dict()
    for key in sorted(state):
        print(key, state[key].shape)
    for param_name, tensor_param in model.named_parameters():
        print(param_name, '\t', tensor_param.requires_grad, '\t', tensor_param.data)


def train(data_loader):
    """Run one training epoch of the module-level ``graph_model``.

    Relies on the globals ``graph_model`` and ``optimizer`` created by the
    script entry point. For every ``(adjacent_matrix, node_attr_matrix,
    y_label)`` batch it:
    - computes the class-weighted BCE summed over the batch
      (``size_average=False``);
    - back-propagates that loss;
    - takes one optimizer step.

    Returns:
        The mean over batches of the per-batch summed loss, as a float.
    """
    graph_model.train()
    total_loss = 0
    for adjacent_matrix, node_attr_matrix, y_label in data_loader:
        adjacent_matrix = tensor_to_variable(adjacent_matrix)
        node_attr_matrix = tensor_to_variable(node_attr_matrix)
        y_label = tensor_to_variable(y_label)
        y_pred = graph_model(adjacent_matrix=adjacent_matrix, node_attr_matrix=node_attr_matrix)
        loss = graph_model.loss_(y_predicted=y_pred, y_actual=y_label, size_average=False)
        # Reduced losses are 0-dim tensors on PyTorch >= 0.4, where
        # ``loss.data[0]`` raises IndexError (>= 0.5). Use .item() when it
        # exists and keep the old indexing only for older versions.
        total_loss += loss.item() if hasattr(loss, 'item') else loss.data[0]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    total_loss /= len(data_loader)
    return total_loss


def make_predictions(data_loader):
    """Run the module-level ``graph_model`` over ``data_loader``.

    Args:
        data_loader: iterable of ``(adjacent_matrix, node_attr_matrix,
            y_label)`` batches, or ``None``.

    Returns:
        ``(y_labels, y_predictions)`` as NumPy arrays stacked over all
        batches, or ``(None, None)`` when ``data_loader`` is ``None``.
    """
    if data_loader is None:
        return None, None
    # Explicit import instead of relying on ``np`` leaking in through the
    # ``from dataloader import *`` star import.
    import numpy as np
    # Inference only: with torch.no_grad() (PyTorch >= 0.4) no autograd graph
    # is built. Older versions fall back to the original behaviour.
    if hasattr(torch, 'no_grad'):
        with torch.no_grad():
            y_label_list, y_pred_list = _collect_predictions(data_loader)
    else:
        y_label_list, y_pred_list = _collect_predictions(data_loader)
    return np.array(y_label_list), np.array(y_pred_list)


def _collect_predictions(data_loader):
    """Return per-sample label and prediction lists from ``graph_model`` over ``data_loader``."""
    y_label_list = []
    y_pred_list = []
    for adjacent_matrix, node_attr_matrix, y_label in data_loader:
        adjacent_matrix = tensor_to_variable(adjacent_matrix)
        node_attr_matrix = tensor_to_variable(node_attr_matrix)
        y_label = tensor_to_variable(y_label)
        y_pred = graph_model(adjacent_matrix=adjacent_matrix, node_attr_matrix=node_attr_matrix)
        y_label_list.extend(variable_to_numpy(y_label))
        y_pred_list.extend(variable_to_numpy(y_pred))
    return y_label_list, y_pred_list


def test(train_dataloader=None, test_dataloader=None):
    """Evaluate the module-level ``graph_model`` and report classification metrics.

    The function:
    - switches the model to eval mode;
    - predicts on whichever loaders are given (a ``None`` loader yields
      ``None`` labels and predictions);
    - passes everything to ``output_classification_result``.

    No validation split is reported. Relies on the globals ``graph_model`` and
    ``target_name``.
    """
    graph_model.eval()
    train_truth, train_scores = make_predictions(train_dataloader)
    test_truth, test_scores = make_predictions(test_dataloader)
    ef_ratios = [0.001, 0.0015, 0.01, 0.02]
    output_classification_result(
        y_train=train_truth,
        y_pred_on_train=train_scores,
        y_val=None,
        y_pred_on_val=None,
        y_test=test_truth,
        y_pred_on_test=test_scores,
        EF_ratio_list=ef_ratios,
        hit_ratio=target_name_to_hit_ratio[target_name],
    )


def save_model(weight_path):
    """Pickle the whole module-level ``graph_model`` (not just its state dict) to ``weight_path``."""
    print('Saving weight path:\t', weight_path)
    with open(weight_path, 'wb') as out_file:
        torch.save(graph_model, out_file)


def load_best_model(weight_path):
    """Load and return a whole pickled model, as written by ``save_model``.

    ``torch.load`` here unpickles arbitrary Python objects. Only use it on
    checkpoint files you trust.
    """
    with open(weight_path, 'rb') as f_:
        try:
            # PyTorch >= 2.6 defaults to weights_only=True, which refuses to
            # rebuild a whole pickled nn.Module; opt out explicitly.
            model = torch.load(f_, weights_only=False)
        except TypeError:
            # Older PyTorch has no ``weights_only`` keyword.
            f_.seek(0)
            model = torch.load(f_)
    return model


if __name__ == '__main__':
    import time
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', action='store', dest='epoch',
                        type=int, required=False, default=20)
    parser.add_argument('--batch_size', action='store', dest='batch_size',
                        type=int, required=False, default=128)
    parser.add_argument('--learning_rate', action='store', dest='learning_rate',
                        type=float, required=False, default=1e-3)
    parser.add_argument('--min_learning_rate', action='store', dest='min_learning_rate',
                        type=float, required=False, default=1e-5)
    parser.add_argument('--seed', action='store', dest='seed',
                        type=int, required=False, default=123)
    parser.add_argument('--target_name', action='store', dest='target_name',
                        type=str, required=False, default='NR-AhR')
    given_args = parser.parse_args()

    # NOTE: target_name, graph_model and optimizer below are intentionally
    # module-level globals. GraphModel.loss_, train, make_predictions and test
    # read them.

    # One target's data is stored as K numbered graph files. The first K - 1
    # are used for training and the last one for testing.
    K = 5
    target_name = given_args.target_name
    directory = '../datasets/tox21/{}/{}_graph.npz'
    file_list = [directory.format(target_name, i) for i in range(K)]

    EPOCHS = given_args.epoch
    BATCH = given_args.batch_size
    # Node slots per graph; presumably must match the padding used by
    # GraphDataSet_Adjacent — NOTE(review): confirm.
    MAX_ATOM_NUM = 55
    LATENT_DIM = 50
    ATOM_FEATURE_DIM = num_atom_features()
    BOND_FEATURE_DIM = num_bond_features()
    torch.manual_seed(given_args.seed)

    graph_model = GraphModel(max_atom_num=MAX_ATOM_NUM,
                             atom_attr_dim=ATOM_FEATURE_DIM,
                             bond_attr_dim=BOND_FEATURE_DIM,
                             latent_dim=LATENT_DIM)
    if torch.cuda.is_available():
        graph_model.cuda()
    print(graph_model)

    # Derived from K (previously hard-coded as [:4] / [4]), so changing the
    # number of files keeps the split consistent.
    train_graph_matrix_file = file_list[:K - 1]
    test_graph_matrix_file = file_list[K - 1]

    train_dataset = GraphDataSet_Adjacent(train_graph_matrix_file)
    test_dataset = GraphDataSet_Adjacent(test_graph_matrix_file)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH, shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH, shuffle=False)

    optimizer = optim.Adam(graph_model.parameters(), lr=given_args.learning_rate)
    # The learning-rate schedule is driven by the training loss (no validation split).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=3,
                                                     min_lr=given_args.min_learning_rate, verbose=True)

    for epoch in range(EPOCHS):
        print('Epoch: {}'.format(epoch))

        train_start_time = time.time()
        train_loss = train(train_dataloader)
        scheduler.step(train_loss)
        train_end_time = time.time()
        print('Train time: {:.3f}s. Train loss is {}.'.format(train_end_time - train_start_time, train_loss))

        # Periodic progress report on the training data only.
        if epoch % 10 == 0:
            test_start_time = time.time()
            test(train_dataloader=train_dataloader, test_dataloader=None)
            test_end_time = time.time()
            print('Test time: {:.3f}s.'.format(test_end_time - test_start_time))
            print()

    # Final report on both the training and held-out test data.
    test_start_time = time.time()
    test(train_dataloader=train_dataloader, test_dataloader=test_dataloader)
    test_end_time = time.time()
    print('Test time: {:.3f}s.'.format(test_end_time - test_start_time))
    print()
