import torch
import torch.nn as nn
from .functions.fakequantn import fakequantn

class FakeQuantN(nn.Module):
    """
    Quantize a given value following Tensorflow Fake Quantization API.

    While in training mode, the module tracks exponential moving averages of
    the minimum and maximum values it sees. When quantization is enabled, it
    calls ``fakequantn`` with those bounds, after widening them so the range
    always contains zero.
    """
    def __init__(self, nbits, init_min=-6., init_max=6., ema_decay=0.999, quantize_at_beginning=False):
        """
        :param int nbits: Number of bits used for representing values.
        :param float init_min: original value for 'min' bound.
        :param float init_max: original value for 'max' bound.
        :param float ema_decay: Exponential moving averages decay.
        :param bool quantize_at_beginning: Quantize from the creation of the model or not.
        """
        super().__init__()

        # Coerce to float so the buffers always get a floating dtype. Passing
        # ints (e.g. init_min=0) would otherwise create int64 buffers, and the
        # EMA updates copied into them would be silently truncated.
        self.register_buffer('running_mean_min', torch.tensor([float(init_min)]))
        self.register_buffer('running_mean_max', torch.tensor([float(init_max)]))
        self.register_buffer('ema_decay', torch.tensor([float(ema_decay)]))
        self._nbits = nbits
        self._quantize = quantize_at_beginning

    def _ema(self, tensor, val):
        """Blend ``val`` into the buffer ``tensor`` in place using ``ema_decay``."""
        tensor.copy_(tensor * self.ema_decay + val * (1 - self.ema_decay))

    def forward(self, inputs):
        """
        Update running statistics (training only) and optionally fake-quantize.

        :param torch.Tensor inputs: values to observe and possibly quantize.
        :return: ``inputs`` unchanged when quantization is disabled, otherwise
            the result of ``fakequantn`` over the zero-including EMA range.
        """
        # Skip empty batches: torch.min/torch.max raise on empty tensors.
        if self.training and inputs.numel() > 0:
            # Statistics are bookkeeping only; keep them out of the autograd graph.
            observed = inputs.detach()
            self._ema(self.running_mean_min, torch.min(observed))
            self._ema(self.running_mean_max, torch.max(observed))

        if not self._quantize:
            return inputs

        # Widen the range to include zero: lower bound is min(0, running_min),
        # upper bound is max(0, running_max).
        q_min = torch.clamp(self.running_mean_min, max=0.).to(inputs.device)
        q_max = torch.clamp(self.running_mean_max, min=0.).to(inputs.device)
        return fakequantn(inputs, q_min, q_max, self._nbits)

    @property
    def nbits(self):
        """Number of bits used for quantization."""
        return self._nbits

    @property
    def quantize(self):
        """Whether ``forward`` applies fake quantization."""
        return self._quantize

    @quantize.setter
    def quantize(self, value):
        self._quantize = value

class FakeQuant8(FakeQuantN):
    """
    FakeQuantN specialised to 8-bit quantization.

    Remaining constructor arguments (init_min, init_max, ema_decay,
    quantize_at_beginning) are forwarded to FakeQuantN, either positionally
    or by keyword.
    """
    def __init__(self, *args, **kwargs):
        # Pass nbits positionally ahead of *args. The previous form,
        # `nbits=8` after *args, bound the first positional argument to
        # `nbits` and raised "multiple values for argument 'nbits'".
        super().__init__(8, *args, **kwargs)
