from spaghettini import quick_register
from functools import partial

from src.data.data_loading.multi_split_loaders import get_multi_split_dataloaders
from src.data.datasets.prefix_sum.prefix_sum import PrefixSumDataset

# Root directory handed to PrefixSumDataset by the loader factory in this module.
# It is a relative path, so it resolves against the process's current working directory.
ROOT: str = "./data"


@quick_register
def get_default_loaders_for_prefix_sum(
    partition_name: str,
    batch_size: int,
    *,
    train_split_ids=(32,),
    eval_split_ids=(16, 32, 64, 128, 256, 512, 1024),
    num_workers_per_dataset: int = 4,
    root=None,
):
    """Build the prefix-sum dataloaders for one partition.

    When ``partition_name == "train"``, the ``train_split_ids`` datasets are
    requested in ``"mixture"`` mode with shuffling enabled. For any other
    partition name, the ``eval_split_ids`` datasets are requested in
    ``"separate"`` mode without shuffling. How each mode combines the splits
    is decided by ``get_multi_split_dataloaders``, which is defined elsewhere.

    Args:
        partition_name: Partition to load. Only ``"train"`` is treated
            specially; every other value takes the evaluation settings.
        batch_size: Forwarded to ``get_multi_split_dataloaders``.
        train_split_ids: Split ids used for the train partition. Each id is
            passed to ``PrefixSumDataset`` as ``num_bits``.
        eval_split_ids: Split ids used for every non-train partition.
        num_workers_per_dataset: Forwarded to ``get_multi_split_dataloaders``.
        root: Dataset root directory. ``None`` (the default) means the
            module-level ``ROOT``, looked up at call time.

    Returns:
        Whatever ``get_multi_split_dataloaders`` returns for these arguments.
    """
    is_train = partition_name == "train"
    mixture_mode = "mixture" if is_train else "separate"
    # Copy into a fresh list: the downstream call has always received a list,
    # and the defaults are immutable tuples to avoid shared mutable defaults.
    split_ids = list(train_split_ids if is_train else eval_split_ids)
    shuffle = is_train
    # Resolve lazily so the module-level ROOT is read at call time, as before.
    dataset_root = ROOT if root is None else root

    def prefix_sum_getter(dataset_partition_name, split_id):
        # NOTE(review): download=False presumably requires the data to already
        # exist under dataset_root — confirm against PrefixSumDataset.
        return PrefixSumDataset(
            root=dataset_root,
            partition_name=dataset_partition_name,
            num_bits=split_id,
            download=False,
        )

    return get_multi_split_dataloaders(
        dataset_getter=prefix_sum_getter,
        partition_name=partition_name,
        mixture_mode=mixture_mode,
        batch_size=batch_size,
        split_ids=split_ids,
        num_workers_per_dataset=num_workers_per_dataset,
        shuffle=shuffle,
    )
