from spaghettini import quick_register

from src.mains.task_getters import get_system_and_trainer


@quick_register
def train(cfg, exp_name, cfg_dir, tmp_dir):
    """Build the Lightning system and trainer for an experiment, then fit it.

    The system/trainer pair is obtained from ``get_system_and_trainer`` with
    ``load_checkpoint_type="recent"`` (i.e. the most recent checkpoint is
    requested). The checkpoint path it returns is handed to ``trainer.fit``
    as ``ckpt_path``, so training continues from that checkpoint rather than
    starting over. NOTE(review): presumably the path is ``None`` when no
    checkpoint exists, which makes ``fit`` start fresh; confirm against
    ``get_system_and_trainer``.

    Args:
        cfg: Experiment configuration, forwarded unchanged.
        exp_name: Experiment name, forwarded unchanged.
        cfg_dir: Config location, forwarded as the ``cfg_path`` keyword.
        tmp_dir: Temporary directory, forwarded unchanged.
    """
    setup = get_system_and_trainer(
        cfg=cfg,
        exp_name=exp_name,
        cfg_path=cfg_dir,
        tmp_dir=tmp_dir,
        load_checkpoint_type="recent",
    )
    # The third element (whether a checkpoint was found) is not needed here;
    # only the path matters for resuming.
    system, lightning_trainer, _checkpoint_found, resume_path = setup

    # Start training, or continue it from ``resume_path`` when one is given.
    lightning_trainer.fit(system, ckpt_path=resume_path)
