import torch
import numpy as np
import torch.nn as nn

from itertools import product
from typing import List, Tuple, Union

from utils.miscellaneous import flatten
from utils.miscellaneous import replace_module

class RFRecorder:
    '''
        Hook module that is used to compute the Receptive Field (RF) of a given layer in
        a convolutional network.

        ** NOTE: For this module to work properly it is mandatory that the key list contains
        **       also the input layer with relevant RF parameters correctly initialized!
    '''

    convlike_l  = (nn.Conv2d, nn.MaxPool2d, nn.AvgPool2d)
    passlike_l  = (nn.AdaptiveAvgPool2d, nn.ReLU, nn.BatchNorm2d, nn.Linear)
    eraselike_l = (nn.ConvTranspose2d, )

    avglike_l  = (nn.AvgPool2d, )

    def __init__(self, keys : List[str], input : Tuple[int, int, int]) -> None:
        '''
            Parameters:
                - keys  : List[str] = ordered layer keys; the first one must be the input layer
                - input : Tuple[int, int, int] = shape of the network input, stored as the input 'rf_shape'
        '''
        self.keys  = np.array(keys)

        # Here we define the RF dictionary that will contain the RF info for each tracked layer
        # in a model. Each entry will itself be a simple dictionary storing the following keys
        # * variable: jump | start | rf_size | rf_shape
        self.RFs = {k : {} for k in self.keys}

        # Here we initialize the input layer parameter
        self.RFs[self.keys[0]]['jump'] = 1
        self.RFs[self.keys[0]]['start'] = .5
        self.RFs[self.keys[0]]['rf_size'] = 1

        self.RFs[self.keys[0]]['rf_shape'] = input

    def __call__(self, module, inp, out) -> None:
        '''
            Forward hook: compute the RF parameters of the layer associated with `module` from
            the parameters of the previously recorded layer.

            Parameters:
                - module : nn.Module = layer whose forward pass triggered the hook
                - inp    : input(s) of the layer (unused)
                - out    : output tensor of the layer, its shape is stored under 'rf_shape'

            Raises:
                - ValueError if `module` is of an unsupported kind or has non-square parameters
        '''
        # Cycle through the keys to prepare for next layer
        # * NOTE: We roll the key list BEFORE processing this layer as first key entry is expected
        # *       to be the input layer with pre-filled relevant values.
        self.keys = np.roll(self.keys, -1)

        curr_l, prev_l = self.keys[[0, -1]]

        # Grab the previous layer RF params
        prev_jump = self.RFs[prev_l]['jump']
        prev_start = self.RFs[prev_l]['start']
        prev_rf_size = self.RFs[prev_l]['rf_size']

        # * Select action based on which module we are scanning
        # If layer acts as a convolution, we update the model RF parameters
        if isinstance(module, self.convlike_l):
            # AvgPool2d has no `dilation` attribute: it behaves as a dense (dilation 1) kernel
            dilation = 1 if isinstance(module, self.avglike_l) else module.dilation

            stride, kernel_size, dilation = map(self._check_reduce, [module.stride, module.kernel_size, dilation])
            padding = self._resolve_padding(module.padding, kernel_size, dilation)

            # Extent (minus one) covered by a possibly dilated kernel: (k - 1) * d.
            # Both the RF growth AND the shift of the first unit centre depend on it.
            span = (kernel_size - 1) * dilation

            # Update the current key-layer RF parameter based on current and last convolution par
            self.RFs[curr_l]['jump'] = prev_jump * stride
            self.RFs[curr_l]['start'] = prev_start + (span / 2 - padding) * prev_jump
            self.RFs[curr_l]['rf_size'] = prev_rf_size + span * prev_jump

        # If we are passing through a BatchNorm or relu, we simply propagate the information
        elif isinstance(module, self.passlike_l):
            self.RFs[curr_l]['jump'] = prev_jump
            self.RFs[curr_l]['start'] = prev_start
            self.RFs[curr_l]['rf_size'] = prev_rf_size

        # If we are passing though a Deconvolution Layer, we reset the parameters
        elif isinstance(module, self.eraselike_l):
            self.RFs[curr_l]['jump'] = 0
            self.RFs[curr_l]['start'] = 0
            self.RFs[curr_l]['rf_size'] = 0

        else:
            msg = f'Encountered layer of unknown kind: {module}'
            raise ValueError(msg)

        self.RFs[curr_l]['rf_shape'] = out.shape

    def _resolve_padding(self, padding, kernel_size, dilation):
        '''
            Convert a layer padding attribute into the (left/top) number of padded pixels.
            String modes accepted by nn.Conv2d are translated: 'valid' pads nothing, 'same'
            pads (d * (k - 1)) // 2 pixels on the leading side.
        '''
        if isinstance(padding, str):
            if padding == 'valid':
                return 0
            if padding == 'same':
                return (dilation * (kernel_size - 1)) // 2
            raise ValueError(f'Unsupported padding mode: {padding}')

        return self._check_reduce(padding)

    def _check_reduce(self, maybe_tuple):
        '''
            Reduce a square 2-tuple/list parameter (e.g. (3, 3)) to its scalar value. Scalars
            are returned unchanged.

            Raises:
                - ValueError if the parameter is not a pair of equal values
        '''
        if isinstance(maybe_tuple, (tuple, list)):
            # Explicit check rather than `assert`, which would be stripped under `python -O`
            if len(maybe_tuple) != 2 or maybe_tuple[0] != maybe_tuple[1]:
                raise ValueError(f'Only square 2D layer parameters are supported, got {maybe_tuple}')

            maybe_tuple = maybe_tuple[0]

        return maybe_tuple

class ForwardRF:
    '''
        Simple class that makes use of the RFRecorder to extract the RFs of a given CNN.
        NOTE: This class DOES NOT behave well with multi-path architectures such as ResNets

        Parameters:
            - input : Tuple[C, W, H] = Tuple containing the Channel (C), Width (W) and Height (H)
                                       of the input that is fed to the network
            - keys  : List[str] = List containing the layer names. It will be used to index the RF
                                  dictionary. If it does not contain the input, one will be added.

        Keyword Args:
            - contains_inp [False] : bool = Flag to signal whether provided keys already contains the
                                            input key
    '''

    def __init__(self, keys : List[str], input : Tuple[int, int, int], contains_inp : bool = False) -> None:
        # Here we define the layer keys that will index the RFrecorder dictionary. We copy into a
        # fresh list so that tuples are accepted and the caller's list is never aliased.
        self.keys = list(keys) if contains_inp else ['input'] + list(keys)
        self.inp_key = self.keys[0]

        # Here we store the input dimension
        self.input = input

        # Here we initialize the receptive field recorder
        self.recorder = RFRecorder(self.keys, input)

        self.has_recorded = False

    def __call__(self, module : nn.Module, device : str, exclude : List[nn.Module] = None) -> dict:
        '''
            Feed a zero mock-up input through `module` and record the RF parameters of every
            hooked leaf layer.

            Parameters:
                - module  : nn.Module = network to analyse
                - device  : str = device on which the mock-up input is allocated
                - exclude : List[type] = layer types that should not be hooked

            Returns:
                The RF dictionary of the recorder, indexed by layer key.
        '''
        exclude = tuple([type(None)] if exclude is None else exclude)

        # Start from a fresh recorder: RFRecorder rolls its key list at every hook call, so
        # re-using it across forward passes would pair layers with the wrong keys.
        self.recorder = RFRecorder(self.keys, self.input)

        # * Here we register the RFRecorder as forward hook
        # Get a reference to the complete set of module layers. We do so by flattening the
        # module object to make sure that only leaf-like submodules are considered (ex: no
        # Sequential or nested modules would be considered)
        layers = flatten(module)

        # Register the forward hooks for each layer targeted as 'traced'
        hook_handles = [l.register_forward_hook(self.recorder) for l in layers if not isinstance(l, exclude)]

        try:
            # Here we create a mock-up input and send it to the network. Only output shapes are
            # needed, so no autograd graph is built.
            mock_inp = torch.zeros(1, *self.input, device = device)
            with torch.no_grad():
                module(mock_inp)
        finally:
            # Always remove the hooks, even when the forward pass fails
            for hook in hook_handles: hook.remove()

        self.has_recorded = True

        return self.recorder.RFs

    def get_unit_rf(self, key : str, unit_pos : Union[Tuple[int, int], List[Tuple[int, int]]] = None) -> np.ndarray:
        '''
            Reconstruct the input-level RF of one or more units of layer `key`.

            Parameters:
                - key      : str = layer key used to index the RF dictionary
                - unit_pos : (i, j) | List[(i, j)] | None = position(s) of the units in the layer
                             feature map (last two output dimensions). If None, all units are used.

            Returns:
                Integer array of bounds ((lo_0, hi_0), (lo_1, hi_1)) per unit along the two spatial
                axes, clipped to the input extent; shape (n_units, 2, 2). For non-convolutional
                layers (output rank <= 2) every unit spans the whole input, and a single (i, j)
                tuple yields one (2, 2) entry.

            Raises:
                - ValueError if the RFs were not computed yet or a unit position is out of range
        '''
        if not self.has_recorded:
            raise ValueError('Before requesting unit RF global model parameter should be computed. Call the ForwardRF instance on a model first.')

        # Get the RF info for the selected layer
        RF = self.recorder.RFs[key]
        inp_shape = self.recorder.RFs[self.inp_key]['rf_shape'][-2:]

        # Get shapes of input and selected layer
        l_shape = RF['rf_shape'][-2:]

        # Check whether a non-conv layer was requested. In that case simply return the input shape
        if len(RF['rf_shape']) <= 2:
            full_img = ((0, inp_shape[0]), (0, inp_shape[1]))
            if unit_pos is None:
                out = [full_img] * (l_shape[0] * l_shape[1])
            elif isinstance(unit_pos, (list, np.ndarray)):
                out = [full_img] * len(unit_pos)
            else:
                out = full_img

            return np.array(out)

        # If no unit is provided, we return the RF for all unit in the given layer
        if unit_pos is None:
            unit_pos = list(product(range(l_shape[0]), range(l_shape[1])))
        elif not isinstance(unit_pos, (list, np.ndarray)):
            unit_pos = [unit_pos]

        # We check whether provided unit position are within selected feature map limits
        check = [pos[0] < 0 or pos[0] >= l_shape[0] or pos[1] < 0 or pos[1] >= l_shape[1] for pos in unit_pos]
        if np.any(check):
            raise ValueError(f'Invalid Unit positions {unit_pos} for layer {key} with features shape {l_shape}')

        # Use the layer RF parameters to reconstruct the unit RFs (centre +/- half size), then
        # clip them to the input extent
        half = RF['rf_size'] / 2
        unit_RFs = []
        for pos in unit_pos:
            centre_0 = RF['start'] + pos[0] * RF['jump']
            centre_1 = RF['start'] + pos[1] * RF['jump']
            unit_RFs.append(((max(0, centre_0 - half), min(inp_shape[0], centre_0 + half)),
                             (max(0, centre_1 - half), min(inp_shape[1], centre_1 + half))))

        # NOTE: builtin `int` instead of `np.int`, which was removed in NumPy 1.24
        return np.array(unit_RFs, dtype = int)

class BackwardRF:
    '''
        This class offers an alternative way of computing units' RF in a convolution network by exploiting the
        backward propagation of the gradient. The RF is defined as the collection of all the input-level pixels
        for which the mock-up gradient is non-zero.

        NOTE: It is rather important to substitute eventual ReLU layers with Identities and MaxPools with
              AvgPools so to facilitate the gradient propagation (that would be halted by the max operation).
              To do so we provide an utility function called `replace_module`. Only the ReLU substitution is
              performed automatically by this class; MaxPool substitution is left to the caller.
    '''

    def __init__(self, device : str = None) -> None:
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

        # Layer outputs collected by `_act_hook`, in the order the forward hooks fire
        self.outs   = []
        # Input-level gradients collected by `_inp_hook`, keyed by the index of the probed layer
        self.grads  = {}

        # Index of the layer currently being back-propagated (read by `_inp_hook`)
        self.curr_l = -1

    def __call__(self, model : nn.Module, inp_shape, units_idx : np.ndarray, exclude = None) -> dict:
        '''
            Back-propagate from selected units of every hooked layer and collect the gradient that
            reaches a random mock-up input.

            Parameters:
                - model     : nn.Module = network to probe; its ReLUs are replaced by nn.Identity via
                                          `replace_module` before the forward pass
                - inp_shape : shape of the mock-up input, either (S0, S1) (3 channels are assumed),
                              (C, S0, S1) or (1, C, S0, S1); a batch dimension of 1 is added if missing
                - units_idx : sequence with one index array per hooked layer (hook firing order). Each
                              array has shape (n_units, k), k in {1, 2, 3}, and indexes the layer output
                              after its batch dimension
                - exclude   : List[type] = layer types that should not be hooked

            Returns:
                dict mapping the layer position (in `units_idx`) to the vertically stacked input
                gradients, one (C, S0, S1) slice per backward pass that reached the input.

            Raises:
                - ValueError for unsupported input shapes or unit index shapes
        '''
        exclude = tuple([type(None)] if exclude is None else exclude)

        if   len(inp_shape) == 4 and inp_shape[0] == 1: inp_shape = tuple(inp_shape)
        elif len(inp_shape) == 3: inp_shape = (1, *inp_shape)
        elif len(inp_shape) == 2: inp_shape = (1, 3, *inp_shape)
        else: raise ValueError(f'Unsupported input shape {inp_shape}')

        # Reset per-call state so that repeated calls never mix activations or gradients
        self.outs   = []
        self.grads  = {}
        self.curr_l = -1

        # Here we replace all ReLUs with Identity layers to facilitate the gradient flow
        # (MaxPools are NOT replaced here, see class docstring)
        model = replace_module(model, nn.ReLU, nn.Identity)

        hook_handles = []
        try:
            # Here we create a mock-up input whose gradient is captured by `_inp_hook`
            self.inp_tensor = torch.randn(*inp_shape, requires_grad = True, device = self.device)
            hook_handles.append(self.inp_tensor.register_hook(self._inp_hook))

            layers = flatten(model)

            # Here we register a simple forward hook for collecting the forward module activation
            hook_handles += [l.register_forward_hook(self._act_hook) for l in layers if not isinstance(l, exclude)]

            # Zero out the model grad so that only RF-relevant grad are accumulated
            model.zero_grad()

            # Compute the model output to get the activation of all selected layers
            model(self.inp_tensor)

            for layer_idx, (out, idxs) in enumerate(zip(self.outs, units_idx)):
                # `_inp_hook` files the gradients under this index
                self.curr_l = layer_idx

                # Select given unit of current output and trigger a backward pass
                if   idxs.shape[1] == 3: iouts = out[0, idxs[:, 0], idxs[:, 1], idxs[:, 2]]
                elif idxs.shape[1] == 2: iouts = out[0, idxs[:, 0], idxs[:, 1]]
                elif idxs.shape[1] == 1: iouts = out[0, idxs[:, 0]]
                else: raise ValueError(f'Unknown unit indexing shape: {idxs.shape}')

                # One backward pass per unit; the graph is kept alive for the following ones
                for iout in iouts:
                    iout.backward(retain_graph = True)
        finally:
            # Clean all the pending hooks, even if something above failed
            for hook in hook_handles: hook.remove()

        # Collect and return all the computed grads
        return {k : np.vstack(v) for k, v in self.grads.items()}

    def _act_hook(self, module, inp, out):
        # Forward hook: store the layer output for the later backward passes
        self.outs += [out]

    def _inp_hook(self, grad):
        # Tensor hook on the mock-up input: store its gradient under the current layer index
        if self.curr_l not in self.grads:
            self.grads[self.curr_l] = []

        self.grads[self.curr_l] += [grad.detach().cpu().numpy()]