from tensorflow.keras.utils import Sequence
import numpy as np

EPS = 1e-8

class ESBatchGenerator(Sequence):
    """Keras ``Sequence`` yielding ``([batch_t, batch_x], batch_y)`` batches.

    Each batch pairs selected rows of ``X`` with a normalised time grid
    ``batch_t`` (values ``0/L, 1/L, ..., (L-1)/L`` where ``L = X.shape[1]``,
    shape ``(batch, L, 1)``) and with the selected labels tiled ``L`` times
    along axis 1.

    Two modes:

    * ``randomize=True``: ``idx`` is ignored; every call draws ``batch_size``
      rows uniformly *with replacement*. An epoch has
      ``config['samples_per_epoch']`` batches.
    * ``randomize=False``: rows are served in order, ``batch_size`` at a
      time (the last batch may be shorter). An epoch has
      ``ceil(N / batch_size)`` batches.

    Args:
        X: array indexed by sample along axis 0; ``X.shape[1]`` is used as
            the sequence length ``L``.
        y: per-sample labels indexed along axis 0. NOTE(review): the tiling
            produces ``(batch, L)`` targets only for 1-D ``y``; a 2-D ``y``
            of shape ``(N, C)`` would yield ``(batch, 1, C * L)`` — confirm
            callers pass 1-D labels.
        config: mapping that must contain ``'samples_per_epoch'``.
        batch_size: rows per batch; must be positive.
        randomize: random sampling (True) or sequential batching (False).
        seed: seed for this generator's private RNG. Defaults to 42, the
            previously hard-coded value; ``None`` seeds from OS entropy.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """

    def __init__(self, X, y, config, batch_size, randomize=True, seed=42):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size!r}")
        self.X = X
        self.y = y
        self.N = X.shape[0]
        self.L = X.shape[1]
        self.B = batch_size
        self.num_samples_per_epoch = config['samples_per_epoch']
        self.randomize = randomize
        self.seq_idxs = np.arange(self.N)
        # Integer ceil-division: number of (possibly partial) sequential batches.
        self.num_batches = -(-self.N // self.B)
        # Private RNG instead of reseeding the process-wide NumPy RNG: batch
        # sampling stays reproducible regardless of other np.random users, and
        # constructing a generator no longer perturbs global random state.
        # RandomState(seed) yields the same draws np.random.seed(seed) did.
        self._rng = np.random.RandomState(seed)

    def __len__(self):
        """Return the number of batches in one epoch (always a Python int)."""
        if self.randomize:
            # Cast so len() works even if the config value is a float.
            return int(self.num_samples_per_epoch)
        return int(self.num_batches)

    def __getitem__(self, idx):
        """Return ``([batch_t, batch_x], batch_y)`` for batch ``idx``.

        In random mode ``idx`` is ignored and a fresh batch is sampled with
        replacement; in sequential mode rows ``idx*B .. min((idx+1)*B, N)-1``
        are returned.
        """
        if self.randomize:
            seq_ids_in_batch = self._rng.choice(self.N, (self.B,), replace=True)
        else:
            start = idx * self.B
            seq_ids_in_batch = np.arange(start, min(start + self.B, self.N))
        return self._make_batch(seq_ids_in_batch)

    def _make_batch(self, seq_ids_in_batch):
        """Assemble model inputs and targets for the given row indices."""
        n_rows = len(seq_ids_in_batch)
        batch_x = self.X[seq_ids_in_batch]
        # (1, L, 1) grid of normalised time steps, repeated for every row -> (n, L, 1).
        time_grid = np.arange(self.L, dtype=float)[np.newaxis, :, np.newaxis] / self.L
        batch_t = np.tile(time_grid, (n_rows, 1, 1))
        # Repeat each row's label L times along axis 1 -> (n, L) for 1-D y.
        batch_y = np.tile(np.expand_dims(self.y[seq_ids_in_batch], axis=1), (1, self.L))
        return [batch_t, batch_x], batch_y