import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union


class StraightThrough(nn.Module):
    """Identity module that hands its input back untouched.

    Serves as a no-op placeholder layer.

    :param channel_num: ignored; accepted only so the constructor signature is stable
    """

    def __init__(self, channel_num: int = 1):
        super().__init__()

    def forward(self, input):
        output = input
        return output


def round_ste(x: torch.Tensor):
    """
    Round ``x`` to the nearest integer with a straight-through gradient.

    The forward value equals ``x.round()``. Because the rounding residual is detached,
    the gradient with respect to ``x`` is the identity.
    """
    residual = (x.round() - x).detach()
    return x + residual


def lp_loss(pred, tgt, p=2.0, reduction='none'):
    """
    Reconstruction loss measured with the L_p norm.

    With ``reduction='none'``, the element-wise ``|pred - tgt| ** p`` is summed over dim 1 and
    then averaged over all remaining elements. Any other ``reduction`` value averages over
    every element.
    """
    err = (pred - tgt).abs().pow(p)
    if reduction != 'none':
        return err.mean()
    return err.sum(1).mean()


class AsymDynamicQuantizer(nn.Module):
    """
    Dynamic quantizer for asymmetric quantization of activation.

    The (min, max) range is recomputed from every input, either over the whole tensor
    (``dim=None``) or reduced over ``dim``. The input is then fake-quantized: quantized to
    ``2 ** n_bits`` levels and immediately dequantized.

    :param n_bits: activation bit-width; 32 disables quantization
    :param dim: dimension(s) reduced when computing min/max; None uses one range for the tensor
    :param fix_abits: if True, ``set_n_bits`` leaves the bit-width unchanged
    """
    def __init__(self, n_bits, dim=None, fix_abits=False):
        super(AsymDynamicQuantizer, self).__init__()
        self.dim = dim
        self.n_bits = n_bits
        self.n_levels = 2.**self.n_bits
        self.fix_abits = fix_abits

    def forward(self, x: torch.Tensor):
        if self.n_bits == 32:  # do not quantize 32-bit activations
            return x

        # quantization range (min, max)
        if self.dim is None:
            # layer-wise quantization
            x_max = x.max()
            x_min = x.min()
        else:
            # channel-wise quantization
            x_max = x.amax(self.dim, keepdim=True)
            x_min = x.amin(self.dim, keepdim=True)
        q_range = x_max - x_min

        # A constant tensor/slice (min == max) would give a zero step and 0/0 = NaN below.
        # Use a placeholder step of 1 there (keeps every intermediate finite, including
        # gradients) and return the input unchanged for those slices.
        degenerate = q_range == 0
        q_step = torch.where(degenerate, torch.ones_like(q_range), q_range / (self.n_levels - 1))

        # scale & shift so quantized values are integers at distance 1
        q_zero_point = (x_min / q_step).round()
        x_q = (x / q_step).round()
        # centralize to "0", clamp to (2^n_bits) levels, then shift back
        x_q = (x_q - q_zero_point).clamp(0, self.n_levels - 1) + q_zero_point
        # dequantize
        x_dq = x_q * q_step
        return torch.where(degenerate, x, x_dq)

    def set_n_bits(self, n_bits):
        """Change the bit-width unless it was fixed at construction (``fix_abits``)."""
        if not self.fix_abits:
            self.n_bits = n_bits
            self.n_levels = 2.**self.n_bits

    def extra_repr(self):
        s = 'bit={n_bits}, fix_abits={fix_abits}, dim={dim}'
        return s.format(**self.__dict__)


class UniformAffineQuantizer(nn.Module):
    """
    PyTorch Function that can be used for asymmetric quantization (also called uniform affine
    quantization). Quantizes its argument in the forward pass, passes the gradient 'straight
    through' on the backward pass, ignoring the quantization that occurred.
    Based on https://arxiv.org/abs/1806.08342.

    ``delta`` (step size) and ``zero_point`` are calibrated lazily from the first tensor seen by
    ``forward`` or ``init_quantize``. Alternatively, ``gen_delta_before_load_calibrated_model``
    allocates them as zero placeholders.

    :param n_bits: number of bits for quantization
    :param symmetric: if True, the zero_point is always 0 and values are clamped to
        [-(n_levels + 1), n_levels] with n_levels = 2**(n_bits - 1) - 1
    :param channel_wise: if True, compute scale and zero_point separately for each slice along dim 0
    :param scale_method: 'max' (optionally containing 'scale') or 'mse'; determines the
        quantization scale and zero point
    :param leaf_param: if True, ``delta`` is wrapped in ``nn.Parameter`` after calibration
    :param always_zero: if True, the (asymmetric) zero point is forced to 0
    :param act_quant_mode: accepted for interface compatibility; unused by this class
    """
    def __init__(self, n_bits: int = 8, symmetric: bool = False, channel_wise: bool = False, scale_method: str = 'max',
            leaf_param: bool = False, always_zero: bool = False, act_quant_mode: str = 'qdiff'):
        super(UniformAffineQuantizer, self).__init__()
        self.sym = symmetric
        self.n_bits = n_bits
        self.n_levels = self._num_levels(self.n_bits)
        self.delta = None
        self.zero_point = None
        self.inited = False
        self.leaf_param = leaf_param
        self.channel_wise = channel_wise
        self.scale_method = scale_method
        self.running_stat = False
        self.always_zero = always_zero
        if self.leaf_param:
            self.x_min, self.x_max = None, None
        # for debugging: number of init_quantization_scale calls
        self.init_count = 0

    def _num_levels(self, n_bits: int):
        # Asymmetric: 2**n_bits levels. Symmetric: size of the positive half of a signed range.
        return 2 ** n_bits if not self.sym else 2 ** (n_bits - 1) - 1

    @staticmethod
    def _channel_view(t: torch.Tensor, ndim: int):
        # Reshape a per-channel vector so it broadcasts against a tensor with `ndim` dims.
        if ndim == 4:
            return t.view(-1, 1, 1, 1)
        if ndim == 3:
            return t.view(-1, 1, 1)
        return t.view(-1, 1)

    def _calibrate(self, x: torch.Tensor):
        # Compute delta / zero_point from x; delta becomes an nn.Parameter when leaf_param.
        delta, self.zero_point = self.init_quantization_scale(x, self.channel_wise)
        self.delta = torch.nn.Parameter(delta) if self.leaf_param else delta
        self.inited = True

    def forward(self, x: torch.Tensor):
        """
        Fake-quantize ``x``: calibrate on first use, round ``x / delta`` with a straight-through
        estimator, shift by ``zero_point``, clamp to the integer range, and dequantize.

        :raises RuntimeError: if the quantization step ``delta`` contains NaN or Inf
        """
        if self.inited is False:
            self._calibrate(x)

        if self.running_stat:
            # NOTE(review): act_momentum_update is not defined in this class; enabling
            # running_stat raises AttributeError unless it is provided elsewhere.
            self.act_momentum_update(x)

        # Fail loudly instead of entering an interactive debugger, which hangs or aborts
        # non-interactive jobs.
        if not torch.isfinite(self.delta).all():
            raise RuntimeError('UniformAffineQuantizer: quantization step (delta) contains NaN or Inf')

        # start quantization
        x_int = round_ste(x / self.delta) + self.zero_point
        if self.sym:
            x_quant = torch.clamp(x_int, -self.n_levels - 1, self.n_levels)
        else:
            x_quant = torch.clamp(x_int, 0, self.n_levels - 1)
        return (x_quant - self.zero_point) * self.delta

    def gen_delta_before_load_calibrated_model(self, x: torch.Tensor):
        """
        Allocate zero-valued ``delta`` / ``zero_point`` placeholders shaped for ``x`` and mark
        the quantizer as initialized. Presumably a calibrated state is loaded over them
        afterwards (TODO confirm).
        """
        if self.channel_wise:
            num_c = x.size(0)
            placeholder = torch.zeros(num_c, dtype=x.dtype, device=x.device)
            self.delta = self._channel_view(placeholder, x.dim())
            self.zero_point = self._channel_view(torch.zeros_like(placeholder), x.dim())
        else:
            self.delta = torch.tensor(0, dtype=x.dtype, device=x.device)
            self.zero_point = torch.tensor(0, dtype=x.dtype, device=x.device)

        if self.leaf_param:
            self.delta = torch.nn.Parameter(self.delta)

        self.inited = True

    def init_quantize(self, x: torch.Tensor):
        """Calibrate ``delta`` / ``zero_point`` from ``x`` if not already initialized."""
        if self.inited is False:
            print("init_quant")
            self._calibrate(x)

    def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False):
        """
        Compute ``(delta, zero_point)`` for ``x`` according to ``self.scale_method``.

        When ``channel_wise`` is True, each slice along dim 0 is calibrated separately, and the
        results are reshaped to broadcast against ``x``.
        """
        self.init_count += 1
        delta, zero_point = None, None
        if channel_wise:
            # NOTE: slow init, since every channel is calibrated separately in a Python loop
            x_clone = x.clone().detach()
            n_channels = x_clone.shape[0]
            delta = x_clone.new_zeros(n_channels)
            zero_point = x_clone.new_zeros(n_channels)
            for c in range(n_channels):
                delta[c], zero_point[c] = self.init_quantization_scale(x_clone[c], channel_wise=False)
            delta = self._channel_view(delta, x.dim())
            zero_point = self._channel_view(zero_point, x.dim())
        else:
            if self.leaf_param:
                self.x_min = x.data.min()
                self.x_max = x.data.max()

            if 'max' in self.scale_method:
                # extend the range so that 0 is always representable
                x_min = min(x.min().item(), 0)
                x_max = max(x.max().item(), 0)
                if 'scale' in self.scale_method:
                    x_min = x_min * (self.n_bits + 2) / 8
                    x_max = x_max * (self.n_bits + 2) / 8

                if self.sym:
                    x_absmax = max(abs(x_min), x_max)
                    delta = x_absmax / self.n_levels
                else:
                    # Use the same (zero-extended, optionally scaled) range the zero point below
                    # is derived from. Otherwise single-sign inputs are clipped, e.g. [1, 2] would
                    # only cover [0, 1].
                    delta = float(x_max - x_min) / (self.n_levels - 1)
                if delta < 1e-8:
                    warnings.warn('Quantization range close to zero: [{}, {}]'.format(x_min, x_max))
                    delta = 1e-8

                zero_point = round(-x_min / delta) if not (self.sym or self.always_zero) else 0
                delta = torch.tensor(delta).type_as(x)

            elif self.scale_method == 'mse':
                # NOTE(review): this search assumes asymmetric quantization (2**n_bits - 1 steps and
                # a [0, n_levels - 1] clamp in `quantize`), which disagrees with n_levels when
                # symmetric=True. For a degenerate range (e.g. an all-zero tensor) every score is
                # NaN, so delta / zero_point stay None.
                x_max = x.max()
                x_min = x.min()
                best_score = 1e+10
                for i in range(80):
                    new_max = x_max * (1.0 - (i * 0.01))
                    new_min = x_min * (1.0 - (i * 0.01))
                    x_q = self.quantize(x, new_max, new_min)
                    # L_p norm minimization as described in LAPQ
                    # https://arxiv.org/abs/1911.07190
                    score = lp_loss(x, x_q, p=2.4, reduction='all')
                    if score < best_score:
                        best_score = score
                        delta = (new_max - new_min) / (2 ** self.n_bits - 1) \
                            if not self.always_zero else new_max / (2 ** self.n_bits - 1)
                        zero_point = (- new_min / delta).round() if not self.always_zero else 0
            else:
                raise NotImplementedError

        return delta, zero_point

    def quantize(self, x, max, min):
        """Fake-quantize ``x`` with an asymmetric range [min, max] (used by the 'mse' search)."""
        delta = (max - min) / (2 ** self.n_bits - 1) if not self.always_zero else max / (2 ** self.n_bits - 1)
        zero_point = (- min / delta).round() if not self.always_zero else 0
        x_int = torch.round(x / delta)
        x_quant = torch.clamp(x_int + zero_point, 0, self.n_levels - 1)
        x_float_q = (x_quant - zero_point) * delta
        return x_float_q

    def bitwidth_refactor(self, refactored_bit: int):
        """Change the bit-width, recomputing ``n_levels`` consistently with ``symmetric``."""
        assert 2 <= refactored_bit <= 8, 'bitwidth not supported'
        self.n_bits = refactored_bit
        self.n_levels = self._num_levels(self.n_bits)

    def extra_repr(self):
        s = 'bit={n_bits}, scale_method={scale_method}, symmetric={sym}, channel_wise={channel_wise},' \
            ' leaf_param={leaf_param}'
        return s.format(**self.__dict__)


class QuantModule(nn.Module):
    """
    Quantized Module that can perform quantized convolution or normal convolution.
    To activate quantization, please use set_quant_state function.

    Wraps an ``nn.Conv2d``, an ``nn.Conv1d``, or an ``nn.Linear``. Any module that is not a
    Conv2d/Conv1d is treated as linear. The wrapper optionally fake-quantizes input activations
    and weights.

    :param org_module: layer to wrap; its weight/bias parameters are shared with this module
    :param weight_quant_params: keyword arguments for the weight ``UniformAffineQuantizer``
    :param act_quant_params: keyword arguments for the activation quantizer. Its
        ``'act_quant_mode'`` entry selects 'qdiff' (``UniformAffineQuantizer``) or 'dynamic'
        (``AsymDynamicQuantizer``) and defaults to 'qdiff' when absent.
    :param disable_act_quant: if True, activations are never quantized
    :param fix_abits: forwarded to ``AsymDynamicQuantizer`` (dynamic mode only)
    """
    def __init__(self, org_module: Union[nn.Conv2d, nn.Conv1d, nn.Linear],
                 weight_quant_params: Union[dict, None] = None,
                 act_quant_params: Union[dict, None] = None, disable_act_quant: bool = False,
                 fix_abits: bool = False):
        super(QuantModule, self).__init__()
        # None defaults avoid sharing one mutable dict across every instance.
        if weight_quant_params is None:
            weight_quant_params = {}
        if act_quant_params is None:
            act_quant_params = {}
        self.weight_quant_params = weight_quant_params
        self.act_quant_params = act_quant_params
        self.aq_dim = None
        self.fix_abits = fix_abits
        if isinstance(org_module, nn.Conv2d):
            self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding,
                                   dilation=org_module.dilation, groups=org_module.groups)
            self.fwd_func = F.conv2d
            self.aq_dim = (1, 2, 3)  # act [N, Cin, H, W]
        elif isinstance(org_module, nn.Conv1d):
            self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding,
                                   dilation=org_module.dilation, groups=org_module.groups)
            self.fwd_func = F.conv1d
            self.aq_dim = (1, 2)  # act [N, Cin, L]
        else:
            self.fwd_kwargs = dict()
            self.fwd_func = F.linear
            self.aq_dim = -1  # act [N, *, Hin]
        self.weight = org_module.weight
        # Full-precision snapshots used when weight quantization is off. They are registered as
        # non-persistent buffers so .to()/.cuda()/.half() move and cast them with the module,
        # while state_dict() keys stay unchanged.
        self.register_buffer('org_weight', org_module.weight.data.clone(), persistent=False)
        if org_module.bias is not None:
            self.bias = org_module.bias
            self.register_buffer('org_bias', org_module.bias.data.clone(), persistent=False)
        else:
            self.bias = None
            self.register_buffer('org_bias', None, persistent=False)
        # de-activate the quantized forward by default
        self.use_weight_quant = False
        self.use_act_quant = False
        self.act_quant_mode = act_quant_params.get('act_quant_mode', 'qdiff')
        self.disable_act_quant = disable_act_quant
        # initialize quantizers
        self.weight_quantizer = UniformAffineQuantizer(**weight_quant_params)
        if self.act_quant_mode == 'qdiff':
            self.act_quantizer = UniformAffineQuantizer(**act_quant_params)
        elif self.act_quant_mode == 'dynamic':
            self.act_quantizer = AsymDynamicQuantizer(self.act_quant_params['n_bits'], dim=self.aq_dim, fix_abits=self.fix_abits)
        self.split = 0

        self.activation_function = StraightThrough()
        self.ignore_reconstruction = False

        self.org_extra_repr = org_module.extra_repr

        # if true, self.dequant_weight caches the dequantized weight used by forward
        self.weight_dequantized = False
        self.dequant_weight = None

    def _quantize_act(self, input: torch.Tensor):
        # Quantize activations per act_quant_mode; any other mode leaves input unchanged.
        if self.act_quant_mode == 'qdiff':
            if self.split != 0:
                # assumes 4-D input; channels [0, split) and [split, C) use separate quantizers
                input_0 = self.act_quantizer(input[:, :self.split, :, :])
                input_1 = self.act_quantizer_0(input[:, self.split:, :, :])
                return torch.cat([input_0, input_1], dim=1)
            return self.act_quantizer(input)
        if self.act_quant_mode == 'dynamic':
            return self.act_quantizer(input)
        return input

    def _quantize_weight(self):
        # Fake-quantize self.weight; with a split, dim-1 slices [0, split) and [split, ...)
        # use separate quantizers.
        if self.split != 0:
            weight_0 = self.weight_quantizer(self.weight[:, :self.split, ...])
            weight_1 = self.weight_quantizer_0(self.weight[:, self.split:, ...])
            return torch.cat([weight_0, weight_1], dim=1)
        return self.weight_quantizer(self.weight)

    def forward(self, input: torch.Tensor, split: int = 0):
        if split != 0 and self.split != 0:
            assert(split == self.split)
        elif split != 0:
            print(f"split at {split}!")
            self.split = split
            self.set_split()

        # activation quantization
        if self.use_act_quant and (not self.disable_act_quant):
            input = self._quantize_act(input)

        # weight quantization
        if self.use_weight_quant:
            weight = self.dequant_weight if self.weight_dequantized else self._quantize_weight()
            bias = self.bias
        else:
            weight = self.org_weight
            bias = self.org_bias

        out = self.fwd_func(input, weight, bias, **self.fwd_kwargs)
        return self.activation_function(out)

    def dequantize_weight(self):
        """Cache the fake-quantized weight so later forwards skip re-quantizing it."""
        if self.use_weight_quant and (not self.weight_dequantized):
            self.dequant_weight = self._quantize_weight()
            self.weight_dequantized = True

    def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
        """Enable or disable weight and activation quantization."""
        self.use_weight_quant = weight_quant
        self.use_act_quant = act_quant

    def set_split(self):
        """Create the second-slice quantizers used once ``self.split`` is non-zero."""
        self.weight_quantizer_0 = UniformAffineQuantizer(**self.weight_quant_params)
        if self.act_quant_mode == 'qdiff':
            self.act_quantizer_0 = UniformAffineQuantizer(**self.act_quant_params)

    def set_running_stat(self, running_stat: bool):
        """Toggle ``running_stat`` on the 'qdiff' activation quantizer(s)."""
        if self.act_quant_mode == 'qdiff':
            self.act_quantizer.running_stat = running_stat
            if self.split != 0:
                self.act_quantizer_0.running_stat = running_stat

    def extra_repr(self):
        s = self.org_extra_repr()
        s += ', use_weight_quant={use_weight_quant}, use_act_quant={use_act_quant}, act_quant_mode={act_quant_mode}, split={split}'.format(**self.__dict__)
        return s
