import numpy as np
import tensorflow as tf

from mayo.override import util
from mayo.override.base import Parameter
from mayo.override.quantize.base import QuantizerBase


class ShiftQuantizer(QuantizerBase):
    """Quantize values to signed powers of two.

    Each value is re-represented as ``sign * 2 ** exponent``, where the
    exponent is the (optionally stochastically) rounded ``log2`` of the
    magnitude, clipped to ``[-bias, 2 ** width - 1 - bias]``.
    """
    # NOTE(review): Parameter arguments look like (name, default, shape,
    # dtype) -- confirm against mayo.override.base.Parameter.
    width = Parameter('width', 8, [], 'int')
    # FIXME exponent_bias name is for backward compatibility
    exponent_bias = Parameter('exponent_bias', 16, [], 'int')

    def __init__(
            self, session, width=None, bias=None,
            overflow_rate=None, stochastic=False,
            asymmetry=True, has_zero=True,
            should_update=True, enable=True):
        """Store the quantizer configuration.

        Args:
            session: forwarded to ``QuantizerBase``.
            width: value assigned to the ``width`` parameter.
            bias: value assigned to the ``exponent_bias`` parameter.
            overflow_rate: overflow rate tolerated when searching for a
                bias; ``None`` lets ``_find_bias`` derive one from `width`.
            stochastic: when truthy, rounding uses
                ``util.stochastic_round`` and this value is passed on as
                its second argument; otherwise ``util.round`` is used.
            asymmetry: when true, negative values have their exponent
                capped one below the maximum exponent.
            has_zero: when true, small magnitudes quantize to exactly 0;
                otherwise every value gets a sign of -1 or 1.
            should_update: forwarded to ``QuantizerBase``.
            enable: forwarded to ``QuantizerBase``.
        """
        super().__init__(
            session=session, should_update=should_update, enable=enable)
        self.width = width
        self.exponent_bias = bias
        self.overflow_rate = overflow_rate
        self.stochastic = stochastic
        self.asymmetry = asymmetry
        self.has_zero = has_zero
        self.max_value = None

    @property
    def bias(self):
        """Alias of ``exponent_bias``."""
        return self.exponent_bias

    @bias.setter
    def bias(self, value):
        self.exponent_bias = value

    def _quantize(self, value, width=None, bias=None):
        """Return `value` re-represented as ``sign * 2 ** exponent``.

        Args:
            value: the values to quantize.
            width: exponent width override; ``self.width`` when ``None``.
            bias: exponent bias override; ``self.bias`` when ``None``.

        Side effects:
            Stores the computed sign and clipped exponent on
            ``self.sign`` and ``self.exponent`` (read back by ``_dump``).
        """
        bias = util.cast(self.bias if bias is None else bias, float)
        width = self.width if width is None else width
        # half of 2 ** -bias, i.e. half of the magnitude that corresponds
        # to the minimum exponent
        discriminator = (2.0 ** (-bias)) / 2.0
        if self.has_zero:
            # sign, which can be -1, 0 and 1; magnitudes not exceeding the
            # discriminator get a sign of 0 and therefore quantize to 0
            sign = util.cast(value > discriminator, int)
            sign -= util.cast(value < -discriminator, int)
            value = util.abs(value)
        else:
            # -1 and 1, no 0
            sign = util.cast(value >= 0, int)
            sign -= util.cast(value < 0, int)
            value = util.max(util.abs(value), discriminator)

        if self.stochastic:
            def rounder(v):
                return util.stochastic_round(v, self.stochastic)
        else:
            rounder = util.round
        # FIXME strange bug that is not present in float quantizer
        # requiring us to use small non-zero value here,
        # even though NaN-values are eliminated below.
        exponent = rounder(util.log(util.max(value, discriminator), 2))
        # clip exponent; the bounds use the resolved `width` so that an
        # explicitly passed width is honoured rather than silently ignored
        exponent_min = -bias
        exponent_max = 2 ** util.cast(width, float) - 1 - bias
        exponent = util.clip_by_value(exponent, exponent_min, exponent_max)
        if self.asymmetry:
            # asymmetric clipping as we preserve the smallest value
            # -2^exponent_max to represent 0
            negative_exponent = util.min(exponent, exponent_max - 1)
            exponent = util.where(sign >= 0, exponent, negative_exponent)
        self.sign, self.exponent = sign, exponent
        # re-represent value
        return util.cast(sign, float) * (2.0 ** exponent)

    def _find_exponent(self, value, orate=None, max_bound=None):
        """Search for an exponent ``e`` such that ``2 ** e`` bounds `value`.

        Candidates run from ``floor(log2(min))`` to ``ceil(log2(max))`` of
        the non-zero magnitudes in `value`.

        Args:
            value: numeric array whose magnitudes are inspected.
            orate: tolerated overflow rate.  ``0`` returns the largest
                candidate immediately.
            max_bound: when given it takes precedence over a non-zero
                `orate`; the first candidate with ``2 ** e > max_bound``
                is returned.

        Returns:
            The selected integer exponent.

        Raises:
            ValueError: if `value` has no non-zero element, or if no
                candidate satisfies the requested criterion.
        """
        # FIXME asymmetric shift may require
        # a different overflow handling for negative numbers
        value = np.abs(value).flatten()
        value = value[value != 0]
        # The emptiness check must precede np.min / np.max, which raise
        # their own ValueError on a zero-size array and would otherwise
        # mask the message below.
        if value.size == 0:
            raise ValueError(
                'Unable to determine exponent, input is a zero tensor.')
        min_exponent = int(np.floor(np.log2(np.min(value))))
        max_exponent = int(np.ceil(np.log2(np.max(value))))
        if orate is not None and orate == 0:
            # return max_exponent directly if we do not allow overflows
            return max_exponent
        for e in range(min_exponent, max_exponent + 1):
            b = 2 ** e
            if max_bound is not None:
                if max_bound < b:
                    return e
            elif orate is not None:
                # `value` holds magnitudes only, so `value < -b` never
                # holds; kept so the call into util is unchanged
                overflows = util.logical_or(value < -b, value > b)
                if self._overflow_rate(overflows) <= orate:
                    return e
        # NOTE(review): this is reached when `max_bound` is at least
        # 2 ** max_exponent, or when both `orate` and `max_bound` are
        # None -- confirm whether callers can trigger either case.
        raise ValueError(
            'This is not reachable, please make sure you have '
            'a correct orate or max_bound.')

    def _find_bias(self, value, width, max_value=None):
        """Return the bias whose maximum exponent matches `value`.

        Args:
            value: array handed to ``_find_exponent``.
            width: exponent width as a plain number.
            max_value: bound handed to ``_find_exponent`` as `max_bound`;
                falls back to ``self.max_value`` when ``None``.

        Returns:
            ``2 ** width - 1 - max_exp`` where ``max_exp`` is the exponent
            found by ``_find_exponent``.
        """
        max_value = max_value if max_value is not None else self.max_value
        levels = 2 ** width
        orate = self.overflow_rate
        if orate is None:
            # automatic overflow rate
            orate = 1 / (levels + 1)
        max_exp = self._find_exponent(value, orate, max_value)
        return levels - 1 - max_exp

    def _update(self):
        """Recompute the bias from the evaluated ``self.before``."""
        value = self.eval(self.before)
        width = self.eval(self.width)
        self.bias = self._find_bias(value, width)

    def search(self, params):
        """Compute a value for the ``exponent_bias`` target.

        Args:
            params: mapping that must provide ``'max'``, ``'targets'``
                (containing ``'exponent_bias'``) and ``'avg'``; the first
                element of ``params['avg']`` is used as the data.

        Returns:
            A dict mapping ``'exponent_bias'`` to the computed bias.

        Raises:
            ValueError: if ``'max'`` is missing or ``'exponent_bias'`` is
                not among the targets.
        """
        max_bound = params.get('max')
        if max_bound is None:
            raise ValueError(
                'Require max value to search for {}.'.format(self))
        targets = params.get('targets')
        if targets is None or 'exponent_bias' not in targets:
            raise ValueError('Required targets are not specified.')
        width = self.eval(self.width)
        # a single bias is derived from the bound; no candidates are
        # compared against each other here
        bias = self._find_bias(params['avg'][0], width, max_bound)
        return {'exponent_bias': bias}

    def _dump(self):
        """Evaluate and return width, bias, sign and the biased exponent."""
        data = {
            'width': self.width,
            'bias': self.bias,
            'sign': self.sign,
            'exponent': self.exponent,
        }
        data = self.session.run(data)
        # the original exponent is the real exponent, adjust it by bias
        data['exponent'] += data['bias']
        data['exponent'] = np.int32(data['exponent'])
        return data

    def _info(self):
        """Return an info tuple with the evaluated integer width and bias."""
        width, bias = self.session.run([self.width, self.bias])
        return self._info_tuple(width=int(width), bias=int(bias))


class ShiftScaleQuantizer(ShiftQuantizer):
    """Shift quantizer whose output is multiplied by a trainable scalar."""

    def _apply(self, value):
        """Run the parent ``_apply`` and multiply its result by ``scale``.

        ``scale`` is a float32 scalar variable named ``'scale'``,
        initialised to 1.0 and created or reused inside ``self._scope``.
        """
        # NOTE: dividing `value` by the square root of its variance before
        # quantizing was sketched here earlier but is not enabled.
        quantized = super()._apply(value)
        # instantiate the trainable scaler
        with tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE):
            scale_spec = dict(
                name='scale',
                shape=[],
                dtype=tf.float32,
                initializer=tf.constant_initializer(1.0),
                trainable=True,
            )
            self.scale = tf.get_variable(**scale_spec)
        return quantized * self.scale

    def _info(self):
        """Return an info tuple with the evaluated width, bias and scale."""
        fields = {}
        fields['width'] = int(self.eval(self.width))
        fields['bias'] = int(self.eval(self.bias))
        fields['scale'] = float(self.eval(self.scale))
        return self._info_tuple(**fields)
