from tensorflow.keras.utils import Sequence
import numpy as np

# Small positive epsilon (1e-8). Not referenced in this chunk — presumably used
# elsewhere in the module as a numerical-stability guard; verify before changing.
EPS = 1.0e-8


class DOSBatchGenerator(Sequence):
    """Keras ``Sequence`` yielding ``(batch_x, batch_y)`` pairs built from ``X`` and ``R``.

    ``batch_x`` is the selected rows of ``X`` with the first column of ``R``
    appended as one extra trailing column; ``batch_y`` is the selected rows of ``R``.

    Modes:

    * ``randomize=True``: every batch draws ``batch_size`` row indices uniformly
      with replacement (the ``idx`` argument is ignored), and an epoch is
      ``config['samples_per_epoch']`` batches long. ``R`` is required.
    * ``randomize=False``: rows are served in order, in consecutive chunks of
      ``batch_size`` (the last chunk may be shorter). ``R`` may be ``None``, in
      which case dummy zero targets are produced.

    Args:
        X: 2-D array of inputs; ``N = X.shape[0]`` rows, ``L = X.shape[1]`` columns.
        R: 2-D array of targets indexed by the same rows as ``X``, or ``None``
            (only with ``randomize=False``). Assumed to have ``N`` rows and ``L``
            columns — TODO confirm against callers.
        config: mapping that provides ``'samples_per_epoch'``.
        batch_size: number of rows per batch.
        randomize: select random sampling (True) or sequential batching (False).

    Raises:
        ValueError: if ``randomize`` is True and ``R`` is None.
    """

    def __init__(self, X, R, config, batch_size, randomize=True):
        # Keras 3's PyDataset (which backs tf.keras.utils.Sequence there) expects
        # subclasses to call super().__init__(); harmless on older tf.keras.
        super().__init__()
        if randomize and R is None:
            # Random sampling is for training and needs real targets; fail early
            # instead of crashing on the first batch.
            raise ValueError("R must be provided when randomize=True")
        self.X = X
        self.R = R
        self.N = X.shape[0]
        self.L = X.shape[1]
        self.B = batch_size
        self.num_samples_per_epoch = config['samples_per_epoch']
        self.randomize = randomize
        self.seq_idxs = np.arange(self.N)
        # Integer ceil(N / B): batches needed to cover every row exactly once.
        self.num_batches = -(-len(self.seq_idxs) // self.B)
        # NOTE: this seeds NumPy's *global* RNG, which __getitem__ samples from;
        # it also affects any other code that uses np.random.
        np.random.seed(42)

    def __len__(self):
        """Return the number of batches per epoch (always an ``int``)."""
        if self.randomize:
            # len() requires an int; the config value may arrive as another numeric type.
            return int(self.num_samples_per_epoch)
        return int(self.num_batches)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(batch_x, batch_y)``.

        In random mode ``idx`` is ignored and a fresh random batch is drawn.
        In sequential mode rows ``[idx * B, min((idx + 1) * B, N))`` are used.
        """
        if self.randomize:
            seq_ids_in_batch = np.random.choice(self.N, (self.B,), replace=True)
        else:
            start = idx * self.B
            seq_ids_in_batch = np.arange(start, min(start + self.B, self.N))
        return self._make_batch(seq_ids_in_batch)

    def _make_batch(self, seq_ids_in_batch):
        """Assemble ``(batch_x, batch_y)`` for the given row indices."""
        x_rows = self.X[seq_ids_in_batch]
        n_rows = len(seq_ids_in_batch)
        if self.R is None:
            # No targets available (only possible with randomize=False): emit
            # dummy zero targets so the batch still has the expected structure.
            batch_y = np.zeros((n_rows, self.L))
            # NOTE(review): zeros stand in for the R[:, 0] input feature so that
            # batch_x keeps the same width as in training; confirm this is the
            # intended conditioning value when R is unavailable.
            first_col = np.zeros((n_rows, 1), dtype=x_rows.dtype)
        else:
            batch_y = self.R[seq_ids_in_batch]
            # [0] (a list) keeps the column 2-D, i.e. shape (n_rows, 1).
            first_col = batch_y[:, [0]]
        batch_x = np.concatenate([x_rows, first_col], axis=1)
        return batch_x, batch_y