from __future__ import print_function

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader


class FingerprintsDataSet(Dataset):
    """Fingerprint strings and label vectors read from one or more CSV files.

    Each value in the ``feature_name`` column is split into single characters
    and every character is converted to a float, giving one feature vector per
    row.

    Args:
        data_path: A CSV path, or a list of CSV paths whose rows are
            concatenated in the given order.
        feature_name: Column holding the fingerprint string.
        label_name_list: List of column names whose values form each row's
            label vector.
    """

    def __init__(self, data_path, feature_name, label_name_list):
        if isinstance(data_path, str):
            data_path = [data_path]

        self.fingerprints_list = []
        self.labels = []
        for path in data_path:
            # Force the fingerprint column to text: otherwise pandas parses a
            # bit-string such as "0011" as the integer 11, dropping leading
            # zeros and making the per-character split below fail.
            data_pd = pd.read_csv(path, dtype={feature_name: str})
            print(data_pd.columns)

            fingerprints = data_pd[feature_name].tolist()
            self.fingerprints_list.extend(list(x) for x in fingerprints)
            labels = data_pd[label_name_list].values.tolist()
            self.labels.extend(labels)

        self.fingerprints_list = np.array(self.fingerprints_list).astype(float)
        self.labels = np.array(self.labels).astype(float)
        print('data shape\t', self.fingerprints_list.shape)

    def __len__(self):
        return len(self.fingerprints_list)

    def __getitem__(self, idx):
        """Return ``(fingerprint, label)`` for row ``idx`` as float64 tensors."""
        fingerprints = torch.from_numpy(self.fingerprints_list[idx])
        label = torch.from_numpy(self.labels[idx])
        return fingerprints, label


class SMILESDataSet(Dataset):
    """SMILES strings paired with float label vectors, read from CSV file(s).

    Args:
        data_path: One CSV path, or a list of paths whose rows are
            concatenated in order.
        feature_name: Column containing the SMILES strings.
        label_name_list: Columns whose values make up each label vector.
    """

    def __init__(self, data_path, feature_name, label_name_list):
        paths = [data_path] if isinstance(data_path, str) else data_path

        self.smiles_list = []
        collected_labels = []
        for csv_path in paths:
            frame = pd.read_csv(csv_path)
            print(frame.columns)
            self.smiles_list.extend(frame[feature_name].tolist())
            collected_labels.extend(frame[label_name_list].values.tolist())

        self.labels = np.array(collected_labels).astype(float)
        print('data length\t', len(self.smiles_list))

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        # The label is returned as a NumPy row; no tensor conversion happens here.
        return self.smiles_list[idx], self.labels[idx]


class Fingerprints_Random_Projection_DataSet(Dataset):
    """Random-projection fingerprint features and labels loaded from ``.npz`` files.

    Args:
        data_file_list: One ``.npz`` path, or a list of paths whose samples
            are concatenated in order.
        feature_name: Key of the feature array inside each archive.
        label_name_list: Keys of the label arrays; they are stacked along a
            new axis 1, so each sample's label has one entry per key.
        n_gram_num: For 4-D feature arrays, how many entries of axis 1 to keep
            (clamped to the number available).
        is_sum_up_feature_segment: For 4-D feature arrays, sum over the last
            axis before flattening instead of keeping it.

    A 4-D feature array, unpacked as
    ``(molecule, n_gram, random_projection_dimension, segment)``, is flattened
    in Fortran order to one row per molecule. Arrays of any other rank are
    used as-is.
    """

    def __init__(self, data_file_list, feature_name, label_name_list, n_gram_num, is_sum_up_feature_segment=False):
        if isinstance(data_file_list, str):
            data_file_list = [data_file_list]

        self.fingerprints_list = []
        self.labels = []
        for data_file in data_file_list:
            # The context manager closes the archive once the arrays are read.
            with np.load(data_file) as data:
                print(data.keys())
                X_data_temp = data[feature_name]
                # np.stack requires a real sequence; a lazy map() object is
                # rejected by current NumPy versions.
                y_data_temp = np.stack([data[name] for name in label_name_list], axis=1)

            if X_data_temp.ndim == 4:
                print('original size\t', X_data_temp.shape)
                X_data_temp = X_data_temp[:, :n_gram_num, ...]
                print('truncated size\t', X_data_temp.shape)
                # Use the n-gram count actually kept: the slice clamps when
                # fewer than n_gram_num exist, so reshaping with n_gram_num
                # itself would fail.
                molecule_num, kept_n_gram_num, random_projection_dimension, segmentation_num = X_data_temp.shape
                if is_sum_up_feature_segment:
                    X_data_temp = X_data_temp.sum(axis=-1)
                    print('sum up along feature segment axis, shape is\t', X_data_temp.shape)
                    segmentation_num = 1
                X_data_temp = X_data_temp.reshape(
                    (molecule_num, kept_n_gram_num * random_projection_dimension * segmentation_num), order='F')
                print('flatten feature segment with random projection, shape is\t', X_data_temp.shape)

            self.fingerprints_list.extend(X_data_temp)
            self.labels.extend(y_data_temp)

        self.fingerprints_list = np.stack(self.fingerprints_list)
        self.labels = np.stack(self.labels)
        print('fingerprints shape\t', self.fingerprints_list.shape)
        print('labels shape\t', self.labels.shape)

    def __len__(self):
        return len(self.fingerprints_list)

    def __getitem__(self, idx):
        """Return ``(features, label)`` tensors for sample ``idx``."""
        fingerprints = torch.from_numpy(self.fingerprints_list[idx])
        label = torch.from_numpy(self.labels[idx])
        return fingerprints, label


class CharacterSMILESDataSet(Dataset):
    """Rows of the ``one_hot_matrix`` array stored in a NumPy archive.

    Args:
        data_path: File passed to ``np.load`` with ``mmap_mode='r'``.
        is_training: Keep only the rows after the first
            ``int(len * split_ratio)`` rows.
        is_validation: Keep only the first ``int(len * split_ratio)`` rows.
            If combined with ``is_training``, this cut is applied to the
            already-trimmed training rows.
        split_ratio: Fraction of rows forming the leading (validation) part.
    """

    def __init__(self, data_path, is_training=False, is_validation=False, split_ratio=0.):
        archive = np.load(data_path, 'r')
        self.one_hot_matrix = archive['one_hot_matrix']

        if is_training:
            self.one_hot_matrix = self.one_hot_matrix[self._split_point(split_ratio):]
        if is_validation:
            self.one_hot_matrix = self.one_hot_matrix[:self._split_point(split_ratio)]

        print('data shape\t', self.one_hot_matrix.shape)

    def _split_point(self, split_ratio):
        """Number of leading rows (of the current matrix) in the validation part."""
        return int(len(self) * split_ratio)

    def __len__(self):
        return len(self.one_hot_matrix)

    def __getitem__(self, idx):
        return torch.from_numpy(self.one_hot_matrix[idx])


class GraphDataSet_Adjacent(Dataset):
    """Graph samples (adjacency + node attributes) with labels from ``.npz`` files.

    Each archive must contain ``adjacent_matrix_list``,
    ``node_attribute_matrix_list`` and ``label_name``. Samples from all
    archives are concatenated in the order given.

    Args:
        data_path: One ``.npz`` path or a list of paths.
    """

    def __init__(self, data_path):
        if isinstance(data_path, str):
            data_path = [data_path]

        self.adjacent_matrix = []
        self.node_attr_matrix = []
        self.label = []

        for path in data_path:
            # Close each archive deterministically once its arrays are read.
            with np.load(path) as data:
                self.adjacent_matrix.extend(data['adjacent_matrix_list'])
                self.node_attr_matrix.extend(data['node_attribute_matrix_list'])
                self.label.extend(data['label_name'])

        self.adjacent_matrix = np.stack(self.adjacent_matrix)
        self.node_attr_matrix = np.stack(self.node_attr_matrix)
        self.label = np.stack(self.label)

        print('adjacent matrix:\t', self.adjacent_matrix.shape)
        print('node attribute matrix:\t', self.node_attr_matrix.shape)
        print('label_name:\t\t', self.label.shape)

    def __len__(self):
        return len(self.adjacent_matrix)

    def __getitem__(self, idx):
        """Return ``(adjacency, node_attributes, label)`` tensors for sample ``idx``.

        The label is sliced as ``[idx:idx+1]`` so it keeps a leading axis of
        length 1 rather than collapsing to a 0-d value.
        """
        adjacent_matrix = torch.from_numpy(self.adjacent_matrix[idx])
        node_attr_matrix = torch.from_numpy(self.node_attr_matrix[idx])
        label = torch.from_numpy(self.label[idx:idx + 1])
        return adjacent_matrix, node_attr_matrix, label


class GraphDataSet_Distance_Adjacent(Dataset):
    """Graph samples with adjacency, inverse-distance and node-attribute matrices.

    Each ``.npz`` archive must contain ``adjacent_matrix_list``,
    ``distance_matrix_list``, ``node_attribute_matrix_list`` and
    ``label_name``. After loading, ``self.distance_matrix`` holds the
    element-wise reciprocal of the stored distances, with 0 wherever the stored
    distance is 0.

    Args:
        data_path: One ``.npz`` path or a list of paths concatenated in order.
    """

    def __init__(self, data_path):
        if isinstance(data_path, str):
            data_path = [data_path]

        self.adjacent_matrix = []
        self.distance_matrix = []
        self.node_attr_matrix = []
        self.label = []

        for path in data_path:
            # Close each archive deterministically once its arrays are read.
            with np.load(path) as data:
                self.adjacent_matrix.extend(data['adjacent_matrix_list'])
                self.distance_matrix.extend(data['distance_matrix_list'])
                self.node_attr_matrix.extend(data['node_attribute_matrix_list'])
                self.label.extend(data['label_name'])

        self.adjacent_matrix = np.stack(self.adjacent_matrix)
        self.node_attr_matrix = np.stack(self.node_attr_matrix)
        distance_matrix = np.stack(self.distance_matrix)
        # With ``where=`` and no ``out=``, NumPy leaves the masked-out
        # positions uninitialized, so zero them explicitly afterwards.
        nonzero = distance_matrix != 0
        inverse_distance = np.divide(1.0, distance_matrix, where=nonzero)
        inverse_distance[~nonzero] = 0
        self.distance_matrix = inverse_distance
        self.label = np.stack(self.label)

        print('adjacent matrix:\t', self.adjacent_matrix.shape)
        print('distance matrix:\t', self.distance_matrix.shape)
        print('node attribute matrix:\t', self.node_attr_matrix.shape)
        print('label_name:\t\t', self.label.shape)

    def __len__(self):
        return len(self.adjacent_matrix)

    def __getitem__(self, idx):
        """Return ``(adjacency, inverse_distance, node_attributes, label)`` tensors.

        The label is sliced as ``[idx:idx+1]`` so it keeps a leading axis of
        length 1.
        """
        adjacent_matrix = torch.from_numpy(self.adjacent_matrix[idx])
        distance_matrix = torch.from_numpy(self.distance_matrix[idx])
        node_attr_matrix = torch.from_numpy(self.node_attr_matrix[idx])
        label = torch.from_numpy(self.label[idx:idx + 1])
        return adjacent_matrix, distance_matrix, node_attr_matrix, label


class GraphDataset_N_Gram_Random_Projection(Dataset):
    """N-gram random-projection graph features with labels, from ``.npz`` files.

    Each archive must contain ``random_projected_list`` (a 4-D array unpacked
    as ``(molecule, n_gram, random_projection_dimension, segment)``) and
    ``label_name``.

    Args:
        data_path: One ``.npz`` path, or a list of paths concatenated in order.
        n_gram_num: How many entries of axis 1 to keep (clamped to the number
            available).
        is_sum_up_feature_segment: If True, sum over the segment axis, giving
            ``(n_gram, random_projection_dimension)`` per molecule. Otherwise
            merge the last two axes in Fortran order, giving
            ``(n_gram, random_projection_dimension * segment)``.
    """

    def __init__(self, data_path, n_gram_num, is_sum_up_feature_segment=False):
        if isinstance(data_path, str):
            data_path = [data_path]

        self.random_projected_matrix = []
        self.label = []

        for path in data_path:
            # Read the needed arrays, then let the context manager close the archive.
            with np.load(path) as data:
                print(data.keys())
                random_projected_matrix = data['random_projected_list']
                label = data['label_name']
            print('original size\t', random_projected_matrix.shape)
            random_projected_matrix = random_projected_matrix[:, :n_gram_num, ...]
            # Take the n-gram count from the sliced array: the slice clamps
            # when fewer than n_gram_num n-grams exist, so reshaping with
            # n_gram_num itself would fail.
            molecule_num, kept_n_gram_num, random_projection_dimension, segmentation_num = random_projected_matrix.shape
            print('truncated size\t', random_projected_matrix.shape)
            if is_sum_up_feature_segment:
                random_projected_matrix = random_projected_matrix.sum(axis=-1)
                print('sum up along feature segment axis, shape is\t', random_projected_matrix.shape)
            else:
                random_projected_matrix = random_projected_matrix.reshape(
                    (molecule_num, kept_n_gram_num, random_projection_dimension * segmentation_num),
                    order='F'
                )
                print('flatten feature segment with random projection, shape is\t', random_projected_matrix.shape)
            self.random_projected_matrix.extend(random_projected_matrix)
            self.label.extend(label)

        self.random_projected_matrix = np.stack(self.random_projected_matrix)
        self.label = np.stack(self.label)

        print('random projected:\t', self.random_projected_matrix.shape)
        print('label_name:\t\t', self.label.shape)

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        """Return ``(features, label)``; the label keeps a leading axis of length 1."""
        random_projected_matrix = torch.from_numpy(self.random_projected_matrix[idx])
        label = torch.from_numpy(self.label[idx:idx + 1])
        return random_projected_matrix, label


class GraphDataset_N_Gram_Embedded_Projection(Dataset):
    """3-D n-gram projection features with labels, from ``.npz`` files.

    Each archive must contain ``random_projected_list`` (a 3-D array unpacked
    as ``(molecule, n_gram, random_projection_dimension)``) and the label
    array named by ``label_name``.

    Args:
        data_path: One ``.npz`` path, or a list of paths concatenated in order.
        n_gram_num: How many entries of axis 1 to keep (clamped to the number
            available).
        label_name: Key of the label array inside each archive.

    Raises:
        ValueError: If ``random_projected_list`` is not 3-D.
    """

    def __init__(self, data_path, n_gram_num, label_name='label_name'):
        if isinstance(data_path, str):
            data_path = [data_path]

        self.random_projected_matrix = []
        self.label = []

        for path in data_path:
            # Read the needed arrays, then let the context manager close the archive.
            with np.load(path) as data:
                print(data.keys())
                random_projected_matrix = data['random_projected_list']
                label = data[label_name]
            print('original size\t', random_projected_matrix.shape)
            if random_projected_matrix.ndim != 3:
                raise ValueError('random_projected_list must be 3-D, got shape %s'
                                 % (random_projected_matrix.shape,))
            # Slicing already yields the final (molecule, n_gram, dimension)
            # layout; no reshape is needed, and the slice tolerates n_gram_num
            # exceeding the stored count.
            random_projected_matrix = random_projected_matrix[:, :n_gram_num, ...]
            print('truncated size\t', random_projected_matrix.shape)
            self.random_projected_matrix.extend(random_projected_matrix)
            self.label.extend(label)

        self.random_projected_matrix = np.stack(self.random_projected_matrix)
        self.label = np.stack(self.label)

        print('random projected:\t', self.random_projected_matrix.shape)
        print('label_name:\t\t', self.label.shape)

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        """Return ``(features, label)``; the label keeps a leading axis of length 1."""
        random_projected_matrix = torch.from_numpy(self.random_projected_matrix[idx])
        label = torch.from_numpy(self.label[idx:idx + 1])
        return random_projected_matrix, label


class GraphDataset_N_Gram_Embedding(Dataset):
    """Node attributes, truncated incidence matrices and labels from ``.npz`` files.

    Each archive must contain ``node_attribute_matrix_list``,
    ``incidence_matrix_list`` and ``label_name``. Only the first
    ``n_gram_num`` entries of the incidence matrices' last axis are kept.

    Args:
        data_path: One ``.npz`` path, or a list of paths concatenated in order.
        n_gram_num: Number of last-axis entries of each incidence matrix to keep.
    """

    def __init__(self, data_path, n_gram_num=4):
        if isinstance(data_path, str):
            data_path = [data_path]

        self.node_attr_matrix = []
        self.incidence_matrix_list = []
        self.label = []

        for path in data_path:
            # Close each archive deterministically once its arrays are read.
            with np.load(path) as data:
                self.node_attr_matrix.extend(data['node_attribute_matrix_list'])
                self.incidence_matrix_list.extend(data['incidence_matrix_list'][..., :n_gram_num])
                self.label.extend(data['label_name'])

        self.node_attr_matrix = np.stack(self.node_attr_matrix)
        self.incidence_matrix_list = np.stack(self.incidence_matrix_list)
        self.label = np.stack(self.label)

        print('node attribute matrix:\t', self.node_attr_matrix.shape)
        print('incidence matrix:\t', self.incidence_matrix_list.shape)
        print('label_name:\t\t', self.label.shape)

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        """Return ``(node_attributes, incidence, label)`` tensors for sample ``idx``.

        The label keeps a leading axis of length 1 (sliced as ``[idx:idx+1]``).
        """
        node_attr_matrix = torch.from_numpy(self.node_attr_matrix[idx])
        incidence_matrix = torch.from_numpy(self.incidence_matrix_list[idx])
        label = torch.from_numpy(self.label[idx:idx + 1])
        return node_attr_matrix, incidence_matrix, label


if __name__ == '__main__':
    # Smoke test: load two graph shards with each graph dataset and print the
    # shapes of the first few batches.
    data_files = ['../datasets/keck_pria_lc/0_graph.npz', '../datasets/keck_pria_lc/1_graph.npz']

    dataset = GraphDataSet_Adjacent(data_files)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1024,
                                             shuffle=False,
                                             num_workers=4)

    for batch_id, (adjacent_matrix, node_attr_matrix, label) in enumerate(dataloader):
        print(batch_id, '\t', adjacent_matrix.size(), '\t', node_attr_matrix.size(), '\t', label.size())
        if batch_id >= 10:
            break

    # This loop unpacks 4-tuples, so it needs the distance-aware dataset
    # (GraphDataSet_Adjacent yields 3-tuples).
    # NOTE(review): assumes these shards also contain 'distance_matrix_list' — confirm.
    dataset = GraphDataSet_Distance_Adjacent(data_files)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1024,
                                             shuffle=False,
                                             num_workers=4)

    for batch_id, (adjacent_matrix, distance_matrix, node_attr_matrix, label) in enumerate(dataloader):
        print(batch_id, '\t', adjacent_matrix.size(), '\t', distance_matrix.size(), '\t', node_attr_matrix.size(), '\t', label.size())
        if batch_id >= 10:
            break