"""Configurations for the gradient descent experiments."""
import itertools

from scripts.classification.gradient_descent import Config


def add_config(*, name, **kwargs):
    """Build a ``Config`` named ``name`` and register it in ``CONFIGS``.

    Args:
        name: Key under which the config is stored in ``CONFIGS``; it is also
            passed to ``Config`` as its ``name`` argument.
        **kwargs: All remaining keyword arguments, forwarded to ``Config``.

    Raises:
        ValueError: If a config named ``name`` is already registered. Without
            this check a colliding name would silently replace the earlier
            config and that experiment would disappear from the registry.
    """
    if name in CONFIGS:
        raise ValueError(f'Duplicate config name: {name!r}')
    CONFIGS[name] = Config(name=name, **kwargs)


# Registry of experiment configurations, keyed by config name.
CONFIGS = {}

# Single hand-written config: mnist49, d=8, m=4, N=350, trained for 200k steps
# (the sweep below uses 1M steps).
add_config(
    name='fast',
    dataset='mnist49',
    n_components=8,
    m=4,
    N=350,
    lr=1e-3,
    n_steps=200_000,
    freeze_second_layer=True,
    random_subset=False,
)


# Sweep grid: one config is registered for every combination of these values.
_DATASETS = ['mnist49', 'fashion_mnist_pullover_coat']
_N_COMPONENTS = [8, 16]
_M = [4, 8, 16, 32, 64]
_N = [350, 700]


def _add_sweep_configs():
    """Register one config per point of the (dataset, d, m, N) sweep grid.

    Each config is named ``'{dataset}_d{n_components}_m{m}_N{N}'``. All sweep
    points share n_steps=1_000_000, lr=1e-3, freeze_second_layer=True and
    random_subset=False. The loop lives in a function so that its loop
    variables do not leak into the module namespace (e.g. via
    ``from ... import *``).
    """
    grid = itertools.product(_DATASETS, _N_COMPONENTS, _M, _N)
    for dataset, n_components, m, N in grid:
        add_config(
            name=f'{dataset}_d{n_components}_m{m}_N{N}',
            dataset=dataset,
            n_components=n_components,
            m=m,
            N=N,
            # Training settings, fixed across the whole sweep.
            n_steps=1_000_000,
            lr=1e-3,
            freeze_second_layer=True,
            # Data option, also fixed across the sweep.
            random_subset=False,
        )


_add_sweep_configs()
