import tensorflow as tf

from mayo.util import object_from_params


def custom_batch_norm(
        tensor, decay=0.999, center=True, scale=False, weight_scale=False,
        epsilon=0.001, is_training=True, trainable=True,
        overriders=None, session=None, **kwargs):
    """Batch-normalize `tensor` over every axis except the last one.

    The result is computed as `tensor * multiplier + offset`, where the
    multiplier is `rsqrt(var + epsilon)` (optionally times `gamma` and the
    scalar `weight_scale` variable) and the offset is `-mean * multiplier`
    (optionally plus `beta`).  All variables live under the `BatchNorm`
    variable scope.

    Args:
        tensor: input tensor; its last axis is treated as the channel axis.
        decay: decay factor of the moving mean / variance updates.
        center: if true, `beta` is added to the offset.
        scale: if true, the inverse standard deviation is multiplied by
            `gamma`.
        weight_scale: if true, the multiplier is further multiplied by a
            scalar `weight_scale` variable.  This happens after the offset
            has been derived, so the offset is not rescaled by it.
        epsilon: constant added to the variance before `rsqrt`.
        is_training: tested with a plain Python `if`.  When true, batch
            statistics are used for normalization and the moving-average
            assignments are added to `tf.GraphKeys.UPDATE_OPS`; otherwise
            the moving statistics are used.
        trainable: whether `gamma`, `beta` and `weight_scale` are created
            as trainable variables (the moving statistics never are).
        overriders: optional mapping with any of the keys `'input'`,
            `'scale'` and `'offset'`; each value's `apply` method is used
            to transform the corresponding tensor.
        session: not used by this function.
        **kwargs: accepted and ignored.

    Returns:
        The normalized tensor.
    """
    channels = tensor.shape[-1]
    scope = 'BatchNorm'
    overriders = overriders or {}

    def apply_overrider(value, key):
        # Pass `value` through the overrider registered for `key`, if any.
        overrider = overriders.get(key)
        if not overrider:
            return value
        oscope = '{}/{}'.format(scope, key)
        return overrider.apply(None, oscope, tf.get_variable, value)

    def get_variable(name, one_init=False, scalar=False, train=True):
        # Create (or fetch) a float32 variable under the BatchNorm scope,
        # either per-channel or scalar, initialized to ones or zeros.
        if one_init:
            initializer = tf.ones_initializer()
        else:
            initializer = tf.zeros_initializer()
        shape = [] if scalar else [channels]
        with tf.variable_scope(scope):
            return tf.get_variable(
                name, shape, initializer=initializer,
                trainable=train and trainable, dtype=tf.float32)

    tensor = apply_overrider(tensor, 'input')

    mean = get_variable('moving_mean', train=False)
    var = get_variable('moving_variance', one_init=True, train=False)
    if is_training:
        axes = list(range(tensor.shape.ndims - 1))
        bmean, bvar = tf.nn.moments(tensor, axes=axes)
        mean_op = tf.assign(mean, decay * mean + (1 - decay) * bmean)
        # Bessel's correction must use the number of samples that were
        # actually reduced per channel (the product of all non-channel
        # dimensions), not just the leading dimension.
        samples = tf.cast(
            tf.reduce_prod(tf.shape(tensor)[:-1]), tf.float32)
        # Clamp the denominator so a single sample does not write inf/NaN
        # into the moving variance.
        uvar = bvar * samples / tf.maximum(samples - 1.0, 1.0)
        var_op = tf.assign(var, decay * var + (1 - decay) * uvar)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, mean_op)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, var_op)
        # Normalize the current batch with its own statistics.
        mean, var = bmean, bvar
    # NOTE(review): gamma and beta are created even when `scale` / `center`
    # are false; kept as-is so the set of variables (and therefore
    # checkpoints) stays unchanged -- confirm whether this is intended.
    gamma = get_variable('gamma', one_init=True)
    beta = get_variable('beta')
    invstd = tf.rsqrt(var + epsilon)
    multiplier = gamma * invstd if scale else invstd
    offset = -mean * multiplier
    if center:
        offset += beta
    if weight_scale:
        multiplier *= get_variable(
            'weight_scale', one_init=True, scalar=True)
    multiplier = apply_overrider(multiplier, 'scale')
    offset = apply_overrider(offset, 'offset')
    return tensor * multiplier + offset
