from torch import Tensor
from torch.utils.data import Dataset


class SimpleDataset(Dataset):
    """Map-style dataset pairing each row of ``X`` with the matching row of ``Y``.

    Both tensors are stored by reference (no copy is made); indexing the
    dataset indexes both tensors with the same key.
    """

    def __init__(
        self,
        X: Tensor,
        Y: Tensor,
    ) -> None:
        """
        Inputs:
            X: Tensor [length, dim]
            Y: Tensor [length, 1]

        Raises:
            ValueError: if X and Y differ in size along the first dimension.
        """
        super().__init__()
        # Fail fast on mismatched inputs: __len__ reports len(Y), so a longer
        # X would be silently truncated and a shorter X would only raise an
        # IndexError later, in the middle of iteration.
        if len(X) != len(Y):
            raise ValueError(
                f"X and Y must have the same length, got {len(X)} and {len(Y)}"
            )
        self.X = X
        self.Y = Y

    def __getitem__(self, idx) -> tuple[Tensor, Tensor]:
        """Return the ``(X[idx], Y[idx])`` pair; ``idx`` is passed to both tensors unchanged."""
        return self.X[idx], self.Y[idx]

    def __len__(self) -> int:
        """Return the number of samples, i.e. the size of ``Y`` along its first dimension."""
        return len(self.Y)