from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import time

import sys
import math
sys.path.insert(0, '../graph_methods')
sys.path.insert(0, '../src')
from dataloader import *
from graph_util import num_atom_features, num_bond_features
from function import reshape_data_into_2_dim
from graph_neural_util import degrees, neural_fingerprint_collate_fn


def rmse(X, Y):
    """Return the root-mean-squared error between arrays *X* and *Y*.

    The mean squared error is printed as a side effect before the root is
    taken.
    """
    mean_squared = np.mean((X - Y) ** 2)
    print('mse: {}'.format(mean_squared))
    return np.sqrt(mean_squared)


def tensor_to_variable(x):
    """Wrap *x* as a float ``Variable``, moving it to the GPU when CUDA is available."""
    device_tensor = x.cuda() if torch.cuda.is_available() else x
    return Variable(device_tensor.float())


def long_tensor_to_variable(x):
    """Wrap *x* as an int64 ``Variable``, moving it to the GPU when CUDA is available."""
    device_tensor = x.cuda() if torch.cuda.is_available() else x
    return Variable(device_tensor.long())


def variable_to_numpy(x):
    """Return the data of Variable/tensor *x* as a numpy array (copied to host if CUDA is available)."""
    host_value = x.cpu() if torch.cuda.is_available() else x
    return host_value.data.numpy()


def sum_and_stack(features, idxs_list_of_lists):
    """Sum the rows of *features* selected by each index group and stack the sums.

    Args:
        features: tensor/Variable whose first axis is indexed.
        idxs_list_of_lists: iterable of index tensors accepted by
            ``index_select``; one output row is produced per group.

    Returns:
        Tensor with one row per group (concatenated along axis 0).
    """
    group_sums = [
        features.index_select(dim=0, index=group).sum(dim=0, keepdim=True)
        for group in idxs_list_of_lists
    ]
    return torch.cat(group_sums, dim=0)


def matmult_neighbors(layer, array_rep, atom_features, bond_features,
                      num_hidden_features, with_degree_layers):
    """Compute the neighbour contribution of graph-convolution layer *layer*.

    For every degree in ``degrees`` that occurs in the batch, the features of
    each atom's neighbouring atoms and bonds are summed, concatenated, and
    passed through the degree-specific linear map
    ``with_degree_layers[layer][degree]``. Degree-0 atoms receive all-zero
    activations of width ``num_hidden_features[layer]``.

    Args:
        layer: index of the current graph-convolution layer.
        array_rep: batch dict; for each degree ``d`` it holds
            ``('atom_neighbors', d)`` and ``('bond_neighbors', d)``. For
            ``d > 0`` these are expected to be 2-D index tensors of shape
            ``(n_atoms_with_degree_d, d)``, as built by
            ``RegressionTask.customized_tensor_to_variable``.
        atom_features: per-atom feature rows indexed by atom id.
        bond_features: per-bond feature rows indexed by bond id.
        num_hidden_features: output width per layer (used for degree 0).
        with_degree_layers: nested module list indexed ``[layer][degree]``.

    Returns:
        Activations for all atoms, concatenated along axis 0 in ``degrees``
        order. NOTE(review): this lines up with ``atom_features`` only if the
        atoms are grouped by degree in the same order — presumably guaranteed
        by ``neural_fingerprint_collate_fn``; confirm.
    """
    def _sum_over_neighbors(features, neighbor_index):
        # Gather all neighbour rows with a single index_select, regroup them as
        # (n_atoms, n_neighbors, n_features) and reduce over the neighbour
        # axis, instead of indexing and summing once per atom in Python.
        n_atoms, n_neighbors = neighbor_index.size(0), neighbor_index.size(1)
        gathered = features.index_select(0, neighbor_index.contiguous().view(-1))
        return gathered.view(n_atoms, n_neighbors, -1).sum(dim=1)

    activations_by_degree = []
    for degree in degrees:
        atom_neighbors_list = array_rep[('atom_neighbors', degree)]
        bond_neighbors_list = array_rep[('bond_neighbors', degree)]

        # No atoms of this degree in the batch: nothing to contribute.
        if len(atom_neighbors_list) == 0:
            continue

        if degree == 0:
            # Isolated atoms have no neighbours, so their contribution is zero.
            activations = tensor_to_variable(
                torch.FloatTensor(len(atom_neighbors_list), num_hidden_features[layer]).zero_())
        else:
            summed_neighbors = torch.cat(
                [_sum_over_neighbors(atom_features, atom_neighbors_list),
                 _sum_over_neighbors(bond_features, bond_neighbors_list)],
                dim=1)
            activations = with_degree_layers[layer][degree](summed_neighbors)
        activations_by_degree.append(activations)

    return torch.cat(activations_by_degree, dim=0)


class Neural_FP_Model(nn.Module):
    """Neural graph-fingerprint network followed by a fully connected head.

    Atom features are refined by ``len(num_hidden_features)`` graph
    convolution layers. Before each layer and after the last one, every atom
    representation is projected to ``fps_length`` values, soft-maxed, and
    summed per molecule (via ``array_rep['atom_list']``). These per-layer
    fingerprints are added together and fed to an MLP whose hidden sizes come
    from ``conf['layer_config']['layers']``. The MLP's output width is the
    number of labels in ``conf['label_name_list']``.

    Args:
        conf: configuration dict; reads ``layer_config.layers``,
            ``layer_config.neural.fps_length``,
            ``layer_config.neural.num_hidden_features`` and
            ``label_name_list``.
    """

    def __init__(self, conf):
        super(Neural_FP_Model, self).__init__()

        layer_config = conf['layer_config']['layers']
        task_num = len(conf['label_name_list'])
        self.fps_length = conf['layer_config']['neural']['fps_length']
        self.num_hidden_features = conf['layer_config']['neural']['num_hidden_features']

        self.all_layer_sizes = [num_atom_features()] + self.num_hidden_features
        # Store a list, not a bare ``zip``: on Python 3 ``zip`` is a one-shot
        # iterator and the attribute would be left exhausted after __init__.
        self.in_and_out_sizes = list(zip(self.all_layer_sizes[:-1], self.all_layer_sizes[1:]))

        # Submodules are created and registered in the same order as before
        # (output_fps_layer, all_fps_layers, batch normalizers, degree layers,
        # layer_sequence). Module.apply visits them in registration order, so
        # seeded re-initialisation draws stay the same.

        # One fingerprint projection per atom representation: the raw input
        # features plus the output of every hidden layer.
        self.output_fps_layer = nn.ModuleList(
            [nn.Linear(size, self.fps_length) for size in self.all_layer_sizes])

        fps_layers = []
        fps_batch_normalizers = []
        degree_layers = []
        for N_prev, N_cur in self.in_and_out_sizes:
            fps_layers.append(nn.Linear(N_prev, N_cur))
            fps_batch_normalizers.append(nn.BatchNorm1d(N_cur))
            # One bias-free neighbour transform per atom degree. Its input is
            # the summed neighbour atom features concatenated with the summed
            # neighbour bond features.
            degree_layers.append(nn.ModuleList(
                [nn.Linear(N_prev + num_bond_features(), N_cur, bias=False) for _ in degrees]))
        self.all_fps_layers = nn.ModuleList(fps_layers)
        self.all_fps_layers_batch_normalizer = nn.ModuleList(fps_batch_normalizers)
        self.with_degree_layers = nn.ModuleList(degree_layers)

        layer_config = [self.fps_length] + layer_config + [task_num]
        self.fc_length = len(layer_config)

        self.layer_sequence = nn.Sequential()
        for layer_id, (layer_first, layer_second) in enumerate(zip(layer_config[:-1], layer_config[1:])):
            layer = nn.Sequential()
            layer.add_module('linear', nn.Linear(layer_first, layer_second))
            # Every layer except the final output layer gets ReLU + BatchNorm.
            if layer_id < self.fc_length - 2:
                layer.add_module('activation', nn.ReLU())
                layer.add_module('batchnorm', nn.BatchNorm1d(layer_second))
            self.layer_sequence.add_module('layer {}'.format(layer_id), layer)

    def forward(self, array_rep):
        """Compute predictions for one collated batch.

        Args:
            array_rep: batch dict with ``'atom_features'``, ``'bond_features'``,
                ``'atom_list'`` (per-molecule atom index tensors) and the
                per-degree neighbour index entries consumed by
                ``matmult_neighbors``.

        Returns:
            Output of ``layer_sequence`` with ``len(label_name_list)`` columns.
        """
        atom_features = array_rep['atom_features']
        bond_features = array_rep['bond_features']
        fingerprints = []

        def write_to_fingerprints(current_atom_features, layer):
            # dim=1 is the feature axis and is what older PyTorch picked
            # implicitly for 2-D input; passing it avoids the deprecation path.
            atom_outputs = F.softmax(self.output_fps_layer[layer](current_atom_features), dim=1)
            fingerprints.append(sum_and_stack(atom_outputs, array_rep['atom_list']))

        def update_layer(layer_id, current_atom_features, bond_features, array_rep):
            self_activations = self.all_fps_layers[layer_id](current_atom_features)
            neighbor_activations = matmult_neighbors(layer_id, array_rep,
                                                     current_atom_features, bond_features,
                                                     self.num_hidden_features, self.with_degree_layers)
            total_activations = self_activations + neighbor_activations
            total_activations = self.all_fps_layers_batch_normalizer[layer_id](total_activations)
            return F.relu(total_activations)

        # Build up neural fingerprint layers (``range`` rather than the
        # Python-2-only ``xrange``).
        num_layers = len(self.num_hidden_features)
        for layer in range(num_layers):
            write_to_fingerprints(atom_features, layer)
            atom_features = update_layer(layer, atom_features, bond_features, array_rep)
        write_to_fingerprints(atom_features, num_layers)

        # Out-of-place sum so no intermediate fingerprint tensor is mutated.
        neural_fingerprint = fingerprints[0]
        for fingerprint in fingerprints[1:]:
            neural_fingerprint = neural_fingerprint + fingerprint

        return self.layer_sequence(neural_fingerprint)

    def loss_(self, y_predicted, y_actual, size_average=True):
        """Return the MSE between prediction and target (mean if *size_average*, else sum)."""
        criterion = nn.MSELoss(size_average=size_average)
        return criterion(y_predicted, y_actual)


class RegressionTask(object):
    """Own a ``Neural_FP_Model``, its optimiser and data loaders, and run training.

    Args:
        conf: configuration dict. Besides the model keys it reads ``seed``,
            ``learning_rate``, ``l2_weight_decay``, ``min_learning_rate``,
            ``epoch``, ``batch_size`` and ``label_name_list``.
        **kwargs: must contain ``file_path`` (where the trained model is
            saved), ``training_dataset`` and ``test_dataset``.
    """

    def __init__(self, conf, **kwargs):
        self.conf = conf
        self.model_weight_file = kwargs['file_path']
        self.label_name_list = conf['label_name_list']

        # Build up model
        self.model = Neural_FP_Model(conf=self.conf)
        print(self.model)
        if torch.cuda.is_available():
            self.model.cuda()

        # weights_init re-initialises Linear and BatchNorm1d parameters after
        # the seed is set.
        torch.manual_seed(conf['seed'])
        self.model.apply(self.weights_init)

        self.optimizer = optim.Adam(self.model.parameters(), lr=conf['learning_rate'], weight_decay=conf['l2_weight_decay'])
        # Halve the learning rate after 5 epochs without improvement in train loss.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5, patience=5,
                                                              min_lr=conf['min_learning_rate'], verbose=True)

        self.epoch = conf['epoch']
        self.batch_size = conf['batch_size']

        self.train_dataloader = torch.utils.data.DataLoader(
            kwargs['training_dataset'], batch_size=self.batch_size,
            shuffle=True, collate_fn=neural_fingerprint_collate_fn)
        self.test_dataloader = torch.utils.data.DataLoader(
            kwargs['test_dataset'], batch_size=self.batch_size,
            shuffle=False, collate_fn=neural_fingerprint_collate_fn)

    def weights_init(self, m):
        """Initialise module *m* in place (meant for ``Module.apply``).

        Linear: weights ~ N(0, 0.02), bias = 1.0.
        BatchNorm1d: weights ~ N(1, 0.02), bias = 0.
        """
        classname = m.__class__.__name__
        if 'Linear' in classname:
            m.weight.data.normal_(0.0, 0.02)
            if m.bias is not None:
                m.bias.data.fill_(1.0)
        elif 'BatchNorm1d' in classname:
            m.weight.data.normal_(1.0, 0.02)
            if m.bias is not None:
                m.bias.data.fill_(0)
        return

    @staticmethod
    def _loss_value(loss):
        """Return a scalar loss as a Python number across PyTorch versions."""
        # ``.item()`` exists from PyTorch 0.4 on, where the loss is 0-dim and
        # ``loss.data[0]`` raises IndexError. Older Variables have no
        # ``.item``, but their ``.data`` holds a 1-element tensor.
        if hasattr(loss, 'item'):
            return loss.item()
        return loss.data[0]

    def make_prediction(self, dataloader):
        """Run the model in eval mode over *dataloader*.

        Returns:
            ``(actual, predicted)`` numpy arrays stacked over all batches.
            The targets are passed through ``reshape_data_into_2_dim``.
        """
        self.model.eval()
        actual = []
        pred = []
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = self.customized_tensor_to_variable(X_batch, y_batch)

            y_pred = self.model(X_batch)
            actual.append(reshape_data_into_2_dim(variable_to_numpy(y_batch)))
            pred.append(variable_to_numpy(y_pred))
        return np.vstack(actual), np.vstack(pred)

    def train(self, data_loader):
        """Run one training epoch over *data_loader*.

        Returns:
            The summed squared error over the epoch divided by
            ``len(data_loader.dataset)``.
        """
        self.model.train()
        total_loss = 0
        for X_batch, y_batch in data_loader:
            X_batch, y_batch = self.customized_tensor_to_variable(X_batch, y_batch)
            self.optimizer.zero_grad()

            y_pred = self.model(X_batch)
            # Give the target the prediction's shape (a no-op when they already
            # match). If the collate function yields 1-D targets, a (B,)
            # target against (B, 1) predictions would otherwise broadcast to
            # (B, B) inside MSELoss on current PyTorch and corrupt the loss.
            y_batch = y_batch.view_as(y_pred)
            loss = self.model.loss_(y_predicted=y_pred, y_actual=y_batch, size_average=False)
            total_loss += self._loss_value(loss)
            loss.backward()
            self.optimizer.step()

        total_loss /= len(data_loader.dataset)
        return total_loss

    def _report_rmse(self):
        """Print the current model's RMSE on the train set and, if present, the test set."""
        y_train, y_pred_on_train = self.make_prediction(self.train_dataloader)
        rmse_train = rmse(y_pred_on_train, y_train)
        print('RMSE on train set: {}'.format(rmse_train))
        if self.test_dataloader is not None:
            y_test, y_pred_on_test = self.make_prediction(self.test_dataloader)
            rmse_test = rmse(y_pred_on_test, y_test)
            print('RMSE on test set: {}'.format(rmse_test))

    def train_and_predict(self):
        """Train for ``self.epoch`` epochs, report RMSE every 10 epochs and at the end, then save."""
        for e in range(self.epoch):
            print('Epoch: {}'.format(e))
            train_loss = self.train(self.train_dataloader)
            self.scheduler.step(train_loss)
            print('Train loss: {}'.format(train_loss))

            if e % 10 == 0:
                self._report_rmse()
                print()

        print()
        self._report_rmse()
        self.save_model(self.model_weight_file)
        return

    def save_model(self, file_path):
        """Pickle the whole model object to *file_path*."""
        print('file path\t', file_path)
        with open(file_path, 'wb') as f_:
            torch.save(self.model, f_)
        return

    def load_best_model(self):
        """Load the model previously saved to ``self.model_weight_file``."""
        return self.load_model(self.model_weight_file)

    def load_model(self, file_path):
        """Load a whole pickled model from *file_path* into ``self.model`` and return it.

        SECURITY: this unpickles arbitrary objects; only load files you trust
        (e.g. ones written by ``save_model``).
        """
        with open(file_path, 'rb') as f_:
            try:
                # PyTorch >= 2.6 defaults to weights_only=True, which refuses
                # to restore a full nn.Module as written by save_model.
                self.model = torch.load(f_, weights_only=False)
            except TypeError:
                # Older PyTorch has no ``weights_only`` keyword.
                f_.seek(0)
                self.model = torch.load(f_)
        return self.model

    def customized_tensor_to_variable(self, array_rep, y_batch):
        """Convert a collated batch to (possibly CUDA) Variables.

        ``array_rep`` is modified in place: feature arrays become float
        Variables, and the per-degree neighbour lists (for ``degrees[1:]``)
        and the ``'atom_list'`` entries become int64 Variables. Empty
        neighbour lists are left as they are.

        Returns:
            ``(array_rep, y_batch)`` with ``y_batch`` as a float Variable.
        """
        array_rep['atom_features'] = tensor_to_variable(torch.FloatTensor(array_rep['atom_features']))
        array_rep['bond_features'] = tensor_to_variable(torch.FloatTensor(array_rep['bond_features']))

        # Degree-0 entries are skipped; matmult_neighbors only needs their length.
        for degree in degrees[1:]:
            atom_neighbors_list = array_rep[('atom_neighbors', degree)]
            bond_neighbors_list = array_rep[('bond_neighbors', degree)]
            if len(atom_neighbors_list) == 0:
                continue
            array_rep[('atom_neighbors', degree)] = long_tensor_to_variable(torch.LongTensor(atom_neighbors_list))
            array_rep[('bond_neighbors', degree)] = long_tensor_to_variable(torch.LongTensor(bond_neighbors_list))

        idxs_list_of_lists = array_rep['atom_list']
        for i, idx_list in enumerate(idxs_list_of_lists):
            array_rep['atom_list'][i] = long_tensor_to_variable(torch.LongTensor(idx_list))

        y_batch = tensor_to_variable(torch.FloatTensor(y_batch))
        return array_rep, y_batch


if __name__ == '__main__':
    # Script entry point: train on shards 1-4 of the Delaney CSVs and
    # evaluate on shard 0.
    import logging
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    neural_conf = {
        'fps_length': 512,
        'num_hidden_features': [50, 30],
    }
    conf = {
        'layer_config': {'neural': neural_conf, 'layers': [2048, 512]},
        'epoch': 100,
        'batch_size': 100,
        'learning_rate': 1e-3,
        'min_learning_rate': 1e-5,
        'l2_weight_decay': 0.0001,
        'seed': 1337,
        'label_name_list': ['delaney'],
    }
    label_name_list = conf['label_name_list']

    K = 5
    directory = '../datasets/delaney/{}.csv.gz'
    file_list = [directory.format(shard_id) for shard_id in range(K)]

    test_index, train_index = slice(0, 1), slice(1, 5)
    train_file_list, test_file_list = file_list[train_index], file_list[test_index]
    print('train files ', train_file_list)
    print('test files ', test_file_list)

    dataset_options = {'feature_name': 'SMILES', 'label_name_list': label_name_list}
    train_dataset = SMILESDataSet(train_file_list, **dataset_options)
    test_dataset = SMILESDataSet(test_file_list, **dataset_options)
    print('Done Loading Test')

    task = RegressionTask(conf=conf,
                          file_path='./temp.pt',
                          training_dataset=train_dataset,
                          test_dataset=test_dataset)
    task.train_and_predict()