import logging
import torch
import torch.nn as nn
import torch.utils.data as data
from antgine.core import flatten_module
from antgine.metrics.utils import AverageMeter
from antgine.callback import Callback
from antgine.modules.quantization import QLinear, QConv2d, FakeQuantN, FakeQuantNBatchNormLinear, FakeQuantNBatchNormConv2d


class QuantizeCallback(Callback):
    """Callback that turns on quantization and freezes batch norm at configured steps.

    A step counter is incremented on every call to :meth:`on_forward_begin`. When the
    counter reaches:

    * ``quantize_at_step``: ``quantize`` is set on every ``FakeQuantN``, ``QLinear``
      and ``QConv2d`` module (optionally after a min/max calibration pass, see ``hook``);
    * ``bn_stats_freeze``: ``bn_stats_freeze`` is set on every fused
      ``FakeQuantNBatchNormLinear`` / ``FakeQuantNBatchNormConv2d`` module;
    * ``bnfreeze``: batch-norm affine parameters stop requiring gradients.
    """

    def __init__(self, model: nn.Module, train_loader: data.DataLoader,
                 quantize_at_step: int, bnfreeze: int, bn_stats_freeze: int, hook=False):
        """
        :param torch.nn.Module model: Model.
        :param torch.utils.data.DataLoader train_loader: Training loader; only iterated
            (once, for calibration) when ``hook`` is true.
        :param int quantize_at_step: Value of the step counter (1 on the first
            ``on_forward_begin`` call) at which quantization is activated.
        :param int bnfreeze: Step at which BatchNorm parameters (weight/bias, or
            gamma/beta for fused modules) are frozen.
        :param int bn_stats_freeze: Step at which the BatchNorm statistics of fused
            FakeQuantNBatchNorm modules are frozen.
        :param bool hook: If true, set each FakeQuantN's running min/max to the average
            of per-batch output min/max over the whole training set instead of keeping
            the exponential moving average (e.g. useful for a pretrained network).
        """
        super().__init__()
        self._model = model
        self._steps = 0
        self._quantize_at_step = quantize_at_step
        self._bnfreeze = bnfreeze
        self._bn_stats_freeze = bn_stats_freeze
        self._hook = hook
        self._train_loader = train_loader

    def on_forward_begin(self, epoch: int, i: int, xs: torch.Tensor, ys: torch.Tensor):
        """Advance the step counter and apply whichever transitions are due at this step.

        The checks run in a fixed order: quantization, BN statistics freeze, then
        BN parameter freeze. If several configured steps coincide, all of them are
        applied during the same call.
        """
        self._steps += 1
        if self._steps == self._quantize_at_step:
            statistics = self._calibrate() if self._hook else None
            self._enable_quantization(statistics)
        if self._steps == self._bn_stats_freeze:
            self._freeze_bn_stats()
        if self._steps == self._bnfreeze:
            self._freeze_bn_params()

    def _calibrate(self):
        """Average the output min/max of every FakeQuantN module over the training set.

        :return: dict mapping each hooked FakeQuantN module to a
            ``(min_meter, max_meter)`` pair of ``AverageMeter``.
        """
        quant_modules = [m for m in flatten_module(self._model) if isinstance(m, FakeQuantN)]
        statistics = {m: (AverageMeter(), AverageMeter()) for m in quant_modules}

        def _populate_statistics(module, inputs, output):
            min_meter, max_meter = statistics[module]
            min_meter.update(output.data.min())
            max_meter.update(output.data.max())

        handles = [m.register_forward_hook(_populate_statistics) for m in quant_modules]
        # NOTE(review): inputs go to the device of the first parameter. This will most
        # likely not work for big / multi-device models, and forward hooks do not fire
        # on nn.DataParallel replicas; if that is an issue, feed a slice of each batch
        # (e.g. xs[:32]) to a single device instead.
        # NOTE(review): the pass runs in the model's current train/eval mode; in train
        # mode any nn.BatchNorm layers also update their running statistics here.
        device = next(self._model.parameters()).device
        try:
            with torch.no_grad():
                for batch_xs, _ in self._train_loader:
                    self._model(batch_xs.to(device))
        finally:
            # Detach the hooks even if calibration fails part-way, so they do not
            # keep firing on every subsequent training forward pass.
            for handle in handles:
                handle.remove()
        return statistics

    def _enable_quantization(self, statistics=None):
        """Set ``quantize`` on all quantizable modules.

        :param dict statistics: Optional output of :meth:`_calibrate`; when given,
            each FakeQuantN found in it has its running min/max overwritten with
            the calibrated averages.
        """
        for module in self._model.modules():
            if isinstance(module, FakeQuantN):
                module.quantize = True
                if statistics is not None:
                    if module in statistics:
                        min_meter, max_meter = statistics[module]
                        module.running_mean_min.fill_(min_meter.avg)
                        module.running_mean_max.fill_(max_meter.avg)
                    else:
                        # Not reached by flatten_module, so no hook was attached:
                        # keep its current running min/max instead of crashing.
                        logging.warning('No calibration statistics for %s', type(module))
                logging.info('Quantization of %s', type(module))
            elif isinstance(module, (QLinear, QConv2d)):
                module.quantize = True
                logging.info('Quantization of %s', type(module))

    def _freeze_bn_stats(self):
        """Set ``bn_stats_freeze`` on every fused FakeQuantNBatchNorm module."""
        for module in self._model.modules():
            if isinstance(module, (FakeQuantNBatchNormLinear, FakeQuantNBatchNormConv2d)):
                module.bn_stats_freeze = True
                logging.info('BN Stats freeze %s', module)

    def _freeze_bn_params(self):
        """Stop gradient computation for batch-norm affine parameters.

        NOTE(review): this only flips ``requires_grad``; whether the optimizer still
        moves these tensors (e.g. via momentum/weight decay on stale, non-None grads)
        depends on how gradients are zeroed elsewhere - confirm.
        """
        for module in self._model.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                # With affine=False, weight and bias are None: nothing to freeze.
                for param in (module.weight, module.bias):
                    if param is not None:
                        param.requires_grad = False
                logging.info('BN Param freeze %s', module)
            elif isinstance(module, (FakeQuantNBatchNormLinear, FakeQuantNBatchNormConv2d)):
                module.gamma.requires_grad = False
                module.beta.requires_grad = False
                logging.info('BN Param freeze %s', module)
