from mayo.log import log
from mayo.override import util
from mayo.override.base import Parameter
from mayo.override.quantize.base import QuantizerBase
import tensorflow as tf
import numpy as np


class TernaryQuantizer(QuantizerBase):
    """
    Ternary quantization, quantizes all values into the set:
        {-m * scale, 0, m * scale},  where m = 2^(base - exponent_bias).

    Positive inputs map to ``m * scale``, negative inputs to ``-m * scale``
    and zeros to zero.  ``scale`` is a trainable scalar variable
    (initialized to 1.0) created in ``_apply``; ``exponent_bias`` is
    re-estimated from the unquantized input (``self.before``) by
    ``_update``.

    Args:
        base: The universal coarse-grain scaling exponent applied to
              ternary weights; must be >= 0.
        stochastic: Stochastic mode is not implemented; any truthy value
              raises NotImplementedError.
    References:
        - Extremely Low Bit Neural Network: Squeeze the Last Bit Out with ADMM
        - Trained Ternary Quantization
    """
    base = Parameter('base', 0, [], 'int', trainable=False)
    exponent_bias = Parameter('exponent_bias', 0, [], 'int')

    def __init__(
            self, session, base=None, stochastic=None,
            should_update=True, enable=True):
        super().__init__(session, should_update, enable)
        if base is not None:
            if base < 0:
                raise ValueError(
                    'Base of ternary quantization must be '
                    'greater or equal than 0.')
            self.base = base
        # Only deterministic quantization exists; an explicit
        # ``stochastic=False`` (e.g. from a config file) is therefore valid
        # and must not be rejected.
        if stochastic:
            raise NotImplementedError(
                'Ternary quantization does not implement stochastic mode.')

    def _quantize(self, value, base=None):
        """Map ``value`` to {-m, 0, m} with m = 2^(base - exponent_bias).

        Args:
            value: Tensor (or array) to quantize.
            base: Optional override for ``self.base``.
        Returns:
            Float tensor of the same shape as ``value``.
        """
        base = util.cast(self.base if base is None else base, int)
        # Raise 2 to the power in floating point.  With integer operands a
        # negative exponent (base < exponent_bias, which ``_update`` produces
        # whenever inputs exceed magnitude 2) is rejected by integer pow in
        # TensorFlow/NumPy instead of yielding the fractional magnitude.
        exponent = util.cast(base - self.exponent_bias, float)
        shift = 2.0 ** exponent
        positives = util.cast(value > 0, float)
        negatives = util.cast(value < 0, float)
        return positives * shift - negatives * shift

    def _apply(self, value):
        """Quantize ``value`` and multiply by a trainable scalar ``scale``."""
        with tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE):
            self.scale = tf.get_variable(
                name='scale', shape=[], dtype=tf.float32,
                initializer=tf.constant_initializer(1.0),
                trainable=True)
        return self._quantize(value) * self.scale

    def _info(self):
        """Summarize the quantizer: width, base and the learned scale(s).

        Reports the scalar ``scale`` if present, otherwise the mean/std of a
        per-channel ``channel_scale``.

        Raises:
            ValueError: if neither scale attribute has been created yet.
        """
        base = int(self.eval(self.base))
        if hasattr(self, 'scale'):
            scale = float(self.eval(self.scale))
            return self._info_tuple(width=2, base=base, scale=scale)
        elif hasattr(self, 'channel_scale'):
            scale = self.eval(self.channel_scale)
            mean = np.mean(scale)
            std = np.std(scale)
            return self._info_tuple(
                width=2, base=base, mean=mean, std=std)
        else:
            raise ValueError('Missing scale in {} quantizer'.format(self.name))

    def find_shift_exp(self, value):
        """Return the smallest exponent in [-256, 256) without overflow.

        An exponent ``exp`` qualifies when ``self._overflow_rate`` reports
        no elements of ``value`` outside [-2^(exp + 1), 2^(exp + 1)].  If no
        exponent qualifies, the last one tried (255) is returned.
        """
        # hard-code an exponent range for now
        max_exponent = int(2 ** 8)
        for exp in range(min(-max_exponent, -4), max(max_exponent, 10)):
            max_value = 2 ** (exp + 1)
            overflows = util.logical_or(
                value < -max_value, value > max_value)
            # no overflow
            if self._overflow_rate(overflows) <= 0.0:
                break
        return exp

    def _update(self):
        """Re-estimate ``exponent_bias`` from the unquantized input."""
        max_exponent = self.find_shift_exp(self.eval(self.before))
        # NOTE(review): ``_quantize`` uses 2^(base - exponent_bias), so
        # storing the max exponent here makes larger inputs produce a
        # *smaller* ternary magnitude -- confirm the intended sign.
        self.exponent_bias = max_exponent



class ChannelTernaryQuantizer(TernaryQuantizer):
    """Ternary quantization with one trainable scaling factor per channel.

    Behaves like ``TernaryQuantizer`` except that the single scalar
    ``scale`` is replaced by a vector ``channel_scale`` whose length is the
    size of the input's last dimension; the ternary output is multiplied by
    it, broadcasting along that last axis.
    """

    def _apply(self, value):
        """Quantize ``value`` and rescale it channel-wise (last axis)."""
        with tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE):
            # One trainable multiplier per channel, all starting at 1.
            num_channels = int(value.shape[-1])
            per_channel = tf.get_variable(
                name='channel_scale',
                shape=[num_channels],
                dtype=tf.float32,
                initializer=tf.ones_initializer(),
                trainable=True)
            self.channel_scale = per_channel
        ternary = self._quantize(value)
        return ternary * per_channel
