from sklearn.preprocessing import StandardScaler
import torch
from copy import deepcopy
from torch.utils.data import Dataset, Subset, ConcatDataset



class SubDataset(Dataset):
    """Sliding-window forecasting dataset over a single series.

    Sample ``i`` pairs the input window ``x[i : i + SEQ_LEN]`` with the target
    window ``x[i + SEQ_LEN : i + SEQ_LEN + PRED_LEN]``. Both windows are
    returned as float tensors of shape ``(length, 1)``, followed by ``type``.

    Args:
        x: The series. It is sliced along axis 0 and, when normalizing,
            reshaped. Presumably a 1-D numpy array; TODO confirm with callers.
        type: Opaque per-series tag, returned unchanged as the third element
            of every sample. (It shadows the builtin; the name is kept for
            caller compatibility.)
        SEQ_LEN: Length of the input window.
        LABEL_LEN: Stored on the instance; not used by this class.
        PRED_LEN: Length of the target window.
        normalize: If True, standardize ``x`` with ``StandardScaler`` fitted on
            the *whole* series. Off by default. Fitting on the full series
            leaks statistics from any later validation/test windows drawn
            from the same series.
    """

    def __init__(self, x, type, SEQ_LEN, LABEL_LEN, PRED_LEN, normalize=False):
        self.x = x
        if normalize:
            self.x = StandardScaler().fit_transform(self.x.reshape(-1, 1)).reshape(-1)

        self.type = type

        self.SEQ_LEN = SEQ_LEN
        self.LABEL_LEN = LABEL_LEN
        self.PRED_LEN = PRED_LEN
        self.WINDOW_LENGTH = SEQ_LEN + PRED_LEN

        # Number of complete windows that fit in the series. Clamp at 0 so a
        # series shorter than one window yields an empty dataset; a negative
        # __len__ would make len() raise ValueError.
        self.LEN = max(0, self.x.shape[0] - self.WINDOW_LENGTH + 1)

    def __getitem__(self, index):
        """Return ``(input_window, target_window, type)`` for window ``index``.

        Negative indices count from the end. Raises IndexError outside
        ``[-len, len)``. Without this check, out-of-range slices silently
        produce short or empty tensors, and plain iteration over the dataset
        (which stops on IndexError) never terminates.
        """
        if index < 0:
            index += self.LEN
        if not 0 <= index < self.LEN:
            raise IndexError(f"window index out of range for dataset of length {self.LEN}")

        s_begin = index
        s_end = s_begin + self.SEQ_LEN
        r_end = s_end + self.PRED_LEN  # target window starts where the input ends

        x = torch.Tensor(self.x[s_begin:s_end]).reshape(-1, 1)
        y = torch.Tensor(self.x[s_end:r_end]).reshape(-1, 1)
        return x, y, self.type

    def __len__(self):
        """Return the number of complete (input, target) windows."""
        return self.LEN


class PowerDataset:
    """Build chronological train/valid/test window datasets for many series.

    ``data`` must expose two attributes:

    * ``total_x``: indexed per user. ``shape[0]`` is read as the user count
      and ``shape[-1]`` as the series length. Presumably a 2-D
      ``(users, time)`` array; TODO confirm.
    * ``node_attr``: one tag per user, forwarded to each ``SubDataset`` as
      ``type``.

    ``data`` is deep-copied, so the caller's object is left untouched. The
    copy kept in ``self.data`` has both of those attributes removed.

    Args:
        data: Container described above.
        split_ratio: ``[train, valid, test]`` fractions of the window count.
            Defaults to ``[0.6, 0.2, 0.2]``. Only the first two entries are
            read; the test split takes every remaining window.
        SEQ_LEN, LABEL_LEN, PRED_LEN: Window sizes forwarded to ``SubDataset``.
    """

    def __init__(self, data, split_ratio=None, SEQ_LEN=96, LABEL_LEN=48, PRED_LEN=7):
        self.split_ratio = [0.6, 0.2, 0.2] if split_ratio is None else split_ratio

        # Work on a private copy so detaching the arrays below does not mutate
        # the caller's object.
        self.data = deepcopy(data)
        self.total_x, self.types = self.data.total_x, self.data.node_attr
        del self.data.total_x, self.data.node_attr

        self.SEQ_LEN, self.LABEL_LEN, self.PRED_LEN = SEQ_LEN, LABEL_LEN, PRED_LEN
        self.WINDOW_LENGTH = SEQ_LEN + PRED_LEN
        # Windows per series, counted along the last axis of total_x. Not
        # clamped: a series shorter than one window gives a value <= 0.
        self.LEN = self.total_x.shape[-1] - self.WINDOW_LENGTH + 1

    def get_dataset(self):
        """Return ``(train_ds, valid_ds, test_ds)`` as ``ConcatDataset``s.

        Window positions are split chronologically:

        * the first ``int(split_ratio[0] * LEN)`` go to train;
        * the next ``int(split_ratio[1] * LEN)`` go to valid;
        * the remainder goes to test.

        The same positions are used for every user. Windows slide by one
        step, so windows on either side of a split boundary can share
        timesteps.
        """
        n_windows = self.LEN
        train_end = int(self.split_ratio[0] * n_windows)
        valid_end = train_end + int(self.split_ratio[1] * n_windows)

        positions = list(range(n_windows))
        split_positions = (
            positions[:train_end],
            positions[train_end:valid_end],
            positions[valid_end:],
        )

        # One bucket per split; each collects one Subset per user.
        buckets = ([], [], [])
        for user in range(self.total_x.shape[0]):
            series = SubDataset(
                x=self.total_x[user],
                type=self.types[user],
                SEQ_LEN=self.SEQ_LEN,
                LABEL_LEN=self.LABEL_LEN,
                PRED_LEN=self.PRED_LEN,
            )
            for bucket, picks in zip(buckets, split_positions):
                bucket.append(Subset(dataset=series, indices=picks))

        train_ds, valid_ds, test_ds = (ConcatDataset(bucket) for bucket in buckets)
        return train_ds, valid_ds, test_ds