class AnnealingScheduler:
    """Base class for schedulers that expose a scalar annealing coefficient.

    Calling an instance returns the current value of ``self.coef``.
    Subclasses are expected to override :meth:`step` to update it.
    """

    def __init__(self, *args, **kwargs):
        # Extra constructor arguments are kept verbatim for subclasses.
        self.kwargs = kwargs
        self.args = args
        self.coef = 1.0

    def __call__(self, *args, **kwargs):
        """Return the current coefficient; any arguments are ignored."""
        current = self.coef
        return current

    def step(self):
        """Advance the schedule by one step (must be overridden)."""
        raise NotImplementedError

class CyclicalAnnealingScheduler(AnnealingScheduler):
    """Cyclical annealing schedule for a scalar coefficient.

    For the first ``warmup_steps`` calls to :meth:`step` the coefficient
    stays at ``min_coef``. After that it repeats cycles of ``cycle_len``
    steps. In each cycle it rises linearly from ``min_coef`` towards
    ``max_coef`` over the first ``cut_step`` steps, then holds at
    ``max_coef`` for the rest of the cycle. Cycles are counted from the end
    of the warmup, so the first post-warmup step always starts a new cycle.

    Args:
        cycle_len: Number of steps per cycle; must be positive.
        min_coef: Coefficient at the start of each cycle and during warmup.
        max_coef: Coefficient held after the ramp within each cycle.
        cut_step: Length of the linear ramp inside a cycle. Defaults to
            ``cycle_len // 2``. A value ``<= 0`` disables the ramp, so the
            coefficient is ``max_coef`` for every post-warmup step.
        warmup_steps: Number of initial steps held at ``min_coef``; must be
            non-negative.
        *args, **kwargs: Forwarded to :class:`AnnealingScheduler`.
        verbose: Keyword-only. If true (the default), print the coefficient
            on every :meth:`step`.

    Raises:
        ValueError: If ``cycle_len`` is not positive or ``warmup_steps`` is
            negative.
    """

    def __init__(
            self,
            cycle_len,
            min_coef = 0.0,
            max_coef = 1.0,
            cut_step = None,
            warmup_steps = 0,
            *args,
            verbose = True,
            **kwargs
        ):
        super().__init__(*args, **kwargs)
        # Fail fast here instead of with a ZeroDivisionError inside step().
        if cycle_len <= 0:
            raise ValueError(f'cycle_len must be positive, got {cycle_len}')
        if warmup_steps < 0:
            raise ValueError(
                f'warmup_steps must be non-negative, got {warmup_steps}'
            )

        self.cycle_len = cycle_len
        self.min_coef = min_coef
        self.max_coef = max_coef
        self.cut_step = cycle_len // 2 if cut_step is None else cut_step
        self.warmup_steps = warmup_steps
        self.verbose = verbose

        self.coef_dif = self.max_coef - self.min_coef
        # The coefficient starts (and stays during warmup) at min_coef.
        self.coef = self.min_coef
        self.step_count = 0

    def step(self):
        """Advance the schedule by one step and update ``self.coef``.

        The coefficient is computed for the current step index before the
        counter is incremented, so the first call sets the value for step 0.
        """
        if self.step_count >= self.warmup_steps:
            # Measure the phase from the end of warmup so the first cycle
            # after warmup starts at min_coef instead of mid-cycle.
            phase = (self.step_count - self.warmup_steps) % self.cycle_len
            if phase < self.cut_step:
                self.coef = self.min_coef + phase * self.coef_dif / self.cut_step
            else:
                self.coef = self.max_coef

        if self.verbose:
            print(f'[AnnealingScheduler] coef: {self.coef}')
        self.step_count += 1