import sys
import numpy as np
from tensorflow.keras.metrics import binary_accuracy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, LearningRateScheduler, Callback
from tensorflow.keras.losses import categorical_crossentropy, binary_crossentropy
from tensorflow.keras import backend as K


def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
    """Compute the learning rate of a warmup / hold / cosine-decay schedule.

    The schedule has up to three phases:

    1. warmup: for ``global_step < warmup_steps`` the rate grows linearly
       from ``warmup_learning_rate`` towards ``learning_rate_base``;
    2. hold: when ``hold_base_rate_steps > 0`` the rate stays at
       ``learning_rate_base`` until ``global_step`` exceeds
       ``warmup_steps + hold_base_rate_steps``;
    3. decay: the rate follows half a cosine period down to 0 at
       ``total_steps``.

    Arguments:
        global_step: current training step.
        learning_rate_base: peak learning rate.
        total_steps: total number of training steps.
        warmup_learning_rate: learning rate at step 0 of the warmup.
        warmup_steps: number of warmup steps.
        hold_base_rate_steps: number of steps the base rate is held before
            the decay starts.

    Returns:
        The result of ``np.where`` (a NumPy array) holding the learning rate;
        it is 0.0 wherever ``global_step > total_steps``.

    Raises:
        ValueError: if ``total_steps < warmup_steps``, or if warmup is used
            and ``learning_rate_base < warmup_learning_rate``.
    """
    if warmup_steps > total_steps:
        raise ValueError('total_steps must be larger or equal to '
                         'warmup_steps.')

    # Position inside the decay phase, and the length of that phase.
    steps_into_decay = global_step - warmup_steps - hold_base_rate_steps
    decay_span = float(total_steps - warmup_steps - hold_base_rate_steps)
    cosine_term = np.cos(np.pi * steps_into_decay / decay_span)
    learning_rate = 0.5 * learning_rate_base * (1 + cosine_term)

    if hold_base_rate_steps > 0:
        # Before the hold period is over, stay at the base rate.
        decay_started = global_step > warmup_steps + hold_base_rate_steps
        learning_rate = np.where(decay_started, learning_rate,
                                 learning_rate_base)

    if warmup_steps > 0:
        if learning_rate_base < warmup_learning_rate:
            raise ValueError('learning_rate_base must be larger or equal to '
                             'warmup_learning_rate.')
        slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
        warmup_rate = slope * global_step + warmup_learning_rate
        in_warmup = global_step < warmup_steps
        learning_rate = np.where(in_warmup, warmup_rate, learning_rate)

    # Past the end of training the rate is clamped to zero.
    finished = global_step > total_steps
    return np.where(finished, 0.0, learning_rate)


class LinearScheduler(Callback):
    '''Linear warmup followed by a linear ramp from ``start_lr`` to ``end_lr``.

    The optimizer learning rate is rewritten after every batch. During warmup
    it grows linearly from 0 to ``start_lr`` over ``warmup_epochs`` epochs;
    afterwards it is interpolated linearly from ``start_lr`` to ``end_lr``
    over the remaining ``epochs - warmup_epochs`` epochs.

    Arguments:
        start_lr: learning rate reached at the end of the warmup.
        end_lr: learning rate reached at the end of the linear phase.
        batch_per_epoch: number of batches in one epoch.
        epochs: total number of epochs, warmup included.
        warmup_epochs: number of warmup epochs; 0 disables the warmup.
        verbose: when > 0, print the learning rate at the end of each epoch.

    Attributes:
        history: dict mapping ``'lr'`` and every key found in the batch
            ``logs`` to the list of values recorded after each batch.
    '''

    def __init__(self,
                 start_lr,
                 end_lr,
                 batch_per_epoch,
                 epochs,
                 warmup_epochs=200,
                 verbose=0):
        super(LinearScheduler, self).__init__()

        self.start_lr = start_lr
        self.end_lr = end_lr

        # Batches seen since the current phase (warmup or linear) started.
        self.batch_since_restart = 0

        self.warmup_batches = warmup_epochs * batch_per_epoch
        self.training_batches = (epochs - warmup_epochs) * batch_per_epoch
        # Without any warmup batch the warmup formula would divide by zero,
        # so start directly in the linear phase in that case.
        self.warm = self.warmup_batches > 0
        # Epoch count (1-based) after which the warmup ends.
        self.next_restart = warmup_epochs

        self.verbose = verbose
        self.history = {}

    def clr(self):
        '''Calculate the learning rate for the current batch counter.'''
        if self.warm:
            return self.batch_since_restart / self.warmup_batches * self.start_lr
        coeff = self.batch_since_restart / self.training_batches
        return self.start_lr + coeff * (self.end_lr - self.start_lr)

    def on_train_begin(self, logs=None):
        '''Set the initial learning rate (0 when a warmup is scheduled).'''
        K.set_value(self.model.optimizer.lr, self.clr())

    def on_batch_end(self, batch, logs=None):
        '''Record previous batch statistics and update the learning rate.'''
        logs = logs or {}
        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        self.batch_since_restart += 1
        K.set_value(self.model.optimizer.lr, self.clr())

    def on_epoch_end(self, epoch, logs=None):
        '''Report the learning rate and leave the warmup phase when it is over.'''
        if self.verbose > 0:
            print('\nEpoch %05d: setting learning '
                  'rate to %s.' % (epoch, self.clr()))
        if epoch + 1 == self.next_restart:
            # Warmup finished: restart the batch counter for the linear phase.
            self.batch_since_restart = 0
            self.warm = False
    
class SGDRScheduler(Callback):
    '''Cosine annealing with restarts, optionally preceded by a linear warmup.

    The optimizer learning rate is rewritten after every batch. While
    ``warm`` is true it grows linearly from 0 to ``max_lr``; otherwise it
    follows half a cosine period from ``max_lr`` down to ``min_lr`` over
    ``cycle_length`` epochs.

    The first restart happens after ``warmup_epochs`` epochs, whether or not
    ``warm`` is set; each later restart happens ``cycle_length`` epochs after
    the previous one.

    Arguments:
        min_lr: lower bound of the cosine cycle.
        max_lr: upper bound of the cosine cycle and target of the warmup.
        steps_per_epoch: number of batches in one epoch.
        lr_decay: factor applied to ``min_lr`` and ``max_lr`` at every
            restart that does not end a warmup.
        cycle_length: length of a cosine cycle, in epochs.
        warmup_epochs: number of epochs before the first restart.
        mult_factor: factor applied to ``cycle_length`` (then rounded up) at
            every restart that does not end a warmup.
        warm: whether training starts with a linear warmup.
        verbose: when > 0, print the learning rate at the end of each epoch.

    Attributes:
        history: dict mapping ``'lr'`` and every key found in the batch
            ``logs`` to the list of values recorded after each batch.
    '''

    def __init__(self,
                 min_lr,
                 max_lr,
                 steps_per_epoch,
                 lr_decay=1,
                 cycle_length=10,
                 warmup_epochs=200,
                 mult_factor=2,
                 warm=False,
                 verbose=0):
        super(SGDRScheduler, self).__init__()

        self.min_lr = min_lr
        self.max_lr = max_lr
        self.lr_decay = lr_decay

        # Batches seen since the last restart.
        self.batch_since_restart = 0
        # Epoch count (1-based) after which the next restart happens.
        self.next_restart = warmup_epochs

        self.steps_per_epoch = steps_per_epoch

        self.cycle_length = cycle_length
        self.mult_factor = mult_factor
        self.warm = warm
        # Despite its name, this attribute holds the warmup length in batches.
        self.warmup_epochs = warmup_epochs * self.steps_per_epoch
        self.verbose = verbose
        self.history = {}

    def clr(self):
        '''Calculate the learning rate for the current batch counter.'''
        if self.warm:
            return self.batch_since_restart / self.warmup_epochs * self.max_lr
        fraction_to_restart = self.batch_since_restart / (self.steps_per_epoch * self.cycle_length)
        return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + np.cos(fraction_to_restart * np.pi))

    def on_train_begin(self, logs=None):
        '''Set the learning rate to 0 at the start of training.'''
        K.set_value(self.model.optimizer.lr, 0)

    def on_batch_end(self, batch, logs=None):
        '''Record previous batch statistics and update the learning rate.'''
        logs = logs or {}
        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        self.batch_since_restart += 1
        K.set_value(self.model.optimizer.lr, self.clr())

    def on_epoch_end(self, epoch, logs=None):
        '''Check for end of current cycle, apply restarts when necessary.'''
        if self.verbose > 0:
            print('\nEpoch %05d: setting learning '
                  'rate to %s.' % (epoch, self.clr()))
        if epoch + 1 == self.next_restart:
            self.batch_since_restart = 0
            if not self.warm:
                # A regular cycle just ended: lengthen the next one and
                # shrink the learning-rate bounds.
                self.cycle_length = np.ceil(self.cycle_length * self.mult_factor)
                self.max_lr *= self.lr_decay
                self.min_lr *= self.lr_decay
            # NOTE(review): printed regardless of `verbose`; kept so the
            # existing console output does not change.
            print(epoch, self.cycle_length)
            self.next_restart += self.cycle_length
            self.warm = False

class InvSGDRScheduler(Callback):
    '''Rising learning-rate cycles with restarts, optionally after a warmup.

    The optimizer learning rate is rewritten after every batch. While
    ``warm`` is true it grows linearly from 0 to ``max_lr``; otherwise it
    rises from ``min_lr`` to ``max_lr`` over ``cycle_length`` epochs, either
    linearly (``sawtooth=True``) or along half a cosine period.

    The first restart happens after ``warmup_epochs`` epochs, whether or not
    ``warm`` is set; each later restart happens ``cycle_length`` epochs after
    the previous one.

    Arguments:
        min_lr: learning rate at the start of a cycle.
        max_lr: learning rate at the end of a cycle and target of the warmup.
        steps_per_epoch: number of batches in one epoch.
        lr_decay: factor applied to ``min_lr`` and ``max_lr`` at every
            restart that does not end a warmup.
        cycle_length: length of a cycle, in epochs.
        warmup_epochs: number of epochs before the first restart.
        mult_factor: factor applied to ``cycle_length`` (then rounded up) at
            every restart that does not end a warmup.
        warm: whether training starts with a linear warmup.
        sawtooth: use a linear ramp instead of the cosine shape.
        verbose: when > 0, print the learning rate at the end of each epoch.

    Attributes:
        history: dict mapping ``'lr'`` and every key found in the batch
            ``logs`` to the list of values recorded after each batch.
    '''

    def __init__(self,
                 min_lr,
                 max_lr,
                 steps_per_epoch,
                 lr_decay=1,
                 cycle_length=10,
                 warmup_epochs=200,
                 mult_factor=2,
                 warm=False,
                 sawtooth=False,
                 verbose=0):
        super(InvSGDRScheduler, self).__init__()

        self.min_lr = min_lr
        self.max_lr = max_lr
        self.lr_decay = lr_decay
        self.sawtooth = sawtooth
        # Batches seen since the last restart.
        self.batch_since_restart = 0
        # Epoch count (1-based) after which the next restart happens.
        self.next_restart = warmup_epochs

        self.steps_per_epoch = steps_per_epoch

        self.cycle_length = cycle_length
        self.mult_factor = mult_factor
        self.warm = warm
        # Despite its name, this attribute holds the warmup length in batches.
        self.warmup_epochs = warmup_epochs * self.steps_per_epoch
        self.verbose = verbose
        self.history = {}

    def clr(self):
        '''Calculate the learning rate for the current batch counter.'''
        if self.warm:
            return self.batch_since_restart / self.warmup_epochs * self.max_lr
        fraction_to_restart = self.batch_since_restart / (self.steps_per_epoch * self.cycle_length)
        if self.sawtooth:
            return self.min_lr + (self.max_lr - self.min_lr) * fraction_to_restart
        return self.max_lr - 0.5 * (self.max_lr - self.min_lr) * (1 + np.cos(fraction_to_restart * np.pi))

    def on_train_begin(self, logs=None):
        '''Set the learning rate to 0 at the start of training.'''
        K.set_value(self.model.optimizer.lr, 0)

    def on_batch_end(self, batch, logs=None):
        '''Record previous batch statistics and update the learning rate.'''
        logs = logs or {}
        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        self.batch_since_restart += 1
        K.set_value(self.model.optimizer.lr, self.clr())

    def on_epoch_end(self, epoch, logs=None):
        '''Check for end of current cycle, apply restarts when necessary.'''
        if self.verbose > 0:
            print('\nEpoch %05d: setting learning '
                  'rate to %s.' % (epoch, self.clr()))
        if epoch + 1 == self.next_restart:
            self.batch_since_restart = 0
            if not self.warm:
                # A regular cycle just ended: lengthen the next one and
                # shrink the learning-rate bounds.
                self.cycle_length = np.ceil(self.cycle_length * self.mult_factor)
                self.max_lr *= self.lr_decay
                self.min_lr *= self.lr_decay
            # NOTE(review): printed regardless of `verbose`; kept so the
            # existing console output does not change.
            print(epoch, self.cycle_length)
            self.next_restart += self.cycle_length
            self.warm = False
class TimeStepScheduler(Callback):
    """Linear warmup followed by a piecewise-constant learning rate.

    The learning rate is set at the beginning of every batch. During the
    first ``warmup_epochs * batch_per_epoch`` batches it grows linearly from
    0 to ``rates[0]``. Afterwards it is ``rates[i]``, where ``i`` is the index
    of the first boundary in ``epoch_steps`` (given in epochs) that the
    global batch counter has not yet exceeded; the last boundary's index is
    used once every boundary has been passed.

    When ``rates`` is empty the schedule is generated: the number of stages is
    ``(nb_epochs - warmup_epochs) // steps_change``, stage ``i`` uses the rate
    ``start_lr * coeff_change ** i`` and ends at epoch
    ``(i + 1) * steps_change``.

    Arguments:
        nb_epochs: total number of training epochs.
        batch_per_epoch: number of batches in one epoch.
        warmup_epochs: number of warmup epochs.
        rates: learning rate of each stage.
        epoch_steps: epoch at which each stage ends.
        start_lr: rate of the first generated stage.
        coeff_change: multiplicative factor between generated stages.
        steps_change: length in epochs of each generated stage.
        verbose: when > 0, print the learning rate once per epoch.

    Raises:
        ValueError: if the schedule has to be generated and ``steps_change``
            is not positive, or if the generated schedule has no stage.

    Attributes:
        learning_rates: optimizer learning rate recorded after each batch.
    """

    def __init__(self,
                 nb_epochs,
                 batch_per_epoch,
                 warmup_epochs,
                 rates=None,
                 epoch_steps=None,
                 start_lr=0,
                 coeff_change=0,
                 steps_change=0,
                 verbose=0):
        super(TimeStepScheduler, self).__init__()
        # Fresh lists per instance: a mutable default would be shared by
        # every scheduler built without explicit schedules.
        rates = [] if rates is None else rates
        epoch_steps = [] if epoch_steps is None else epoch_steps

        if len(rates) == 0:
            if steps_change <= 0:
                raise ValueError('steps_change must be positive when rates '
                                 'is empty.')
            nb_stages = (nb_epochs - warmup_epochs) // steps_change
            if nb_stages <= 0:
                raise ValueError('steps_change is too large: no learning '
                                 'rate stage fits in the training epochs.')
            rates = [start_lr * coeff_change ** i for i in range(nb_stages)]
            epoch_steps = [(i + 1) * steps_change for i in range(nb_stages)]

        self.learning_rate_base = rates[0]
        self.total_steps = nb_epochs * batch_per_epoch
        self.global_step = 0
        self.batch_per_epoch = batch_per_epoch
        self.warmup_steps = warmup_epochs * batch_per_epoch
        self.verbose = verbose
        self.rates = rates
        self.epoch_steps = epoch_steps
        self.learning_rates = []

    def on_batch_end(self, batch, logs=None):
        """Advance the global batch counter and record the learning rate."""
        self.global_step = self.global_step + 1
        lr = K.get_value(self.model.optimizer.lr)
        self.learning_rates.append(lr)

    def on_batch_begin(self, batch, logs=None):
        """Set the learning rate for the batch about to run."""
        if self.global_step < self.warmup_steps:
            lr = self.global_step / self.warmup_steps * self.learning_rate_base
        else:
            # Find the first stage whose end (converted to batches) has not
            # been exceeded yet, staying on the last stage afterwards.
            i = 0
            last = len(self.epoch_steps) - 1
            while i < last and self.global_step > self.epoch_steps[i] * self.batch_per_epoch:
                i += 1
            lr = self.rates[i]
        K.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0 and self.global_step % self.batch_per_epoch == 0:
            print('\nBatch %05d: setting learning '
                  'rate to %s.' % (self.global_step + 1, lr))


class WarmUpCosineDecayScheduler(Callback):
    """Callback applying ``cosine_decay_with_warmup`` before every batch."""

    def __init__(self,
                 nb_epochs,
                 batch_per_epoch,
                 learning_rate_base,
                 warmup_epochs=0,
                 hold_base_rate_epochs=0,
                 warmup_learning_rate=0,
                 verbose=0):
        """Store the schedule settings, converting epoch counts to batch counts.

        Arguments:
            nb_epochs {int} -- total number of training epochs.
            batch_per_epoch {int} -- number of batches in one epoch.
            learning_rate_base {float} -- base learning rate.

        Keyword Arguments:
            warmup_epochs {int} -- number of warmup epochs. (default: {0})
            hold_base_rate_epochs {int} -- number of epochs to hold the base
                                        learning rate before decaying.
                                        (default: {0})
            warmup_learning_rate {float} -- initial learning rate for warm up.
                                        (default: {0})
            verbose {int} -- 0: quiet, 1: update messages. (default: {0})
        """
        super().__init__()
        self.verbose = verbose
        self.batch_per_epoch = batch_per_epoch
        self.learning_rate_base = learning_rate_base
        self.warmup_learning_rate = warmup_learning_rate

        # The schedule itself works in batches, not epochs.
        self.total_steps = nb_epochs * batch_per_epoch
        self.warmup_steps = warmup_epochs * batch_per_epoch
        self.hold_base_rate_steps = hold_base_rate_epochs * batch_per_epoch

        self.global_step = 0
        self.learning_rates = []

    def on_batch_end(self, batch, logs=None):
        """Advance the global batch counter and record the learning rate."""
        self.global_step += 1
        current_lr = K.get_value(self.model.optimizer.lr)
        self.learning_rates.append(current_lr)

    def on_batch_begin(self, batch, logs=None):
        """Set the learning rate for the batch about to run."""
        lr = cosine_decay_with_warmup(
            global_step=self.global_step,
            learning_rate_base=self.learning_rate_base,
            total_steps=self.total_steps,
            warmup_learning_rate=self.warmup_learning_rate,
            warmup_steps=self.warmup_steps,
            hold_base_rate_steps=self.hold_base_rate_steps,
        )
        K.set_value(self.model.optimizer.lr, lr)

        if not self.verbose > 0:
            return
        # Only report on the first batch of each epoch.
        if self.global_step % self.batch_per_epoch == 0:
            message = '\nBatch %05d: setting learning rate to %s.'
            print(message % (self.global_step + 1, lr))

def scheduler(epoch, lr):
    """Return the next learning rate from the current one.

    The rate is multiplied by 1.22 for the first 20 epochs, left unchanged
    until epoch 180, then multiplied by ``exp(-0.08)`` every epoch.

    Arguments:
        epoch: index of the current epoch (0-based).
        lr: current learning rate.

    Returns:
        The learning rate to use for ``epoch``.
    """
    if epoch < 20:
        return lr * 1.22
    if epoch < 180:
        return lr
    # NumPy is used because this module never imports `tf`; the original
    # `tf.math.exp` call raised NameError as soon as epoch reached 180.
    return lr * float(np.exp(-0.08))


def LearningRateSchedulerMaxMin(lr_start=0.01, lr_end=0.0001, epochs=100):
    """Build a ``LearningRateScheduler`` decaying exponentially per epoch.

    The learning rate at epoch ``e`` is ``lr_start * lr_decay ** e`` with
    ``lr_decay = (lr_end / lr_start) ** (1 / epochs)``, so it starts at
    ``lr_start`` and equals ``lr_end`` at epoch ``epochs``.

    The arguments used to be overwritten by the hard-coded values 1e-3 and
    1e-4; they are now honoured.

    Arguments:
        lr_start: learning rate at epoch 0.
        lr_end: learning rate at epoch ``epochs``.
        epochs: number of epochs over which the rate decays.

    Returns:
        A ``LearningRateScheduler`` callback implementing the schedule.
    """
    lr_decay = (lr_end / lr_start) ** (1. / epochs)
    return LearningRateScheduler(lambda e: lr_start * lr_decay ** e)
