from jax.experimental.optimizers import optimizer, make_schedule

import jax.numpy as np
"""
Optimizer library based heavily on jax.experimental.optimizers, with a few small modifications.
"""

@optimizer
def momentum(step_size, mass=0.9, wd=0.):
    """
    Build the optimizer triple for SGD with momentum and weight decay.

    Args:
        step_size: positive scalar, or a callable step size schedule that maps
            the iteration index to a positive scalar.
        mass: positive scalar, the momentum coefficient applied to the
            previous velocity.
        wd: weight-decay coefficient; ``wd * x`` is added to the gradient
            before the velocity update (L2 regularization without the
            factor of 2).
    Returns:
        An (init_fun, update_fun, get_params) triple.
    """
    schedule = make_schedule(step_size)

    def init(params):
        # State is the parameter itself paired with a zero velocity buffer.
        return params, np.zeros_like(params)

    def update(step, grad, opt_state):
        params, vel = opt_state
        # Fold the weight-decay term into the gradient first.
        decayed_grad = grad + wd * params
        new_vel = mass * vel + decayed_grad
        new_params = params - schedule(step) * new_vel
        return new_params, new_vel

    def get_params(opt_state):
        params, _vel = opt_state
        return params

    return init, update, get_params

@optimizer
def adam(step_size, b1=0.9, b2=0.999, wd=0., eps=1e-8):
    """
    Build the optimizer triple for Adam with weight decay.

    Args:
        step_size: positive scalar, or a callable step size schedule that maps
            the iteration index to a positive scalar.
        b1: optional, positive scalar for beta_1, the exponential decay rate
            of the first moment estimate (default 0.9).
        b2: optional, positive scalar for beta_2, the exponential decay rate
            of the second moment estimate (default 0.999).
        wd: optional, nonnegative scalar; ``wd * x`` is added to the gradient
            before the moment estimates are updated (default 0.).
        eps: optional, positive scalar added to the denominator for numerical
            stability (default 1e-8).

    Returns:
        An (init_fun, update_fun, get_params) triple.
    """
    schedule = make_schedule(step_size)

    def init(params):
        # Both moment buffers start at zero, shaped like the parameter.
        first_moment = np.zeros_like(params)
        second_moment = np.zeros_like(params)
        return params, first_moment, second_moment

    def update(step, grad, opt_state):
        params, first_moment, second_moment = opt_state
        # Weight decay enters through the gradient, so it also feeds the moments.
        grad = grad + wd * params
        first_moment = (1 - b1) * grad + b1 * first_moment
        second_moment = (1 - b2) * np.square(grad) + b2 * second_moment
        # Bias correction uses the 1-based step count in the moment dtype.
        t = step + 1
        dtype = first_moment.dtype
        first_hat = first_moment / (1 - np.asarray(b1, dtype) ** t)
        second_hat = second_moment / (1 - np.asarray(b2, dtype) ** t)
        params = params - schedule(step) * first_hat / (np.sqrt(second_hat) + eps)
        return params, first_moment, second_moment

    def get_params(opt_state):
        params, _first, _second = opt_state
        return params

    return init, update, get_params


