import numpy as np

from tqdm import tqdm

from utils.information import H_XY, I_XY, H
from utils.information import bin_data_equipop
from utils.information import bin_data_equisize

from functools import partial
from multiprocessing import Pool

from typing import List, Tuple, Union

def iprofile(
    data : dict,
    acts : dict,
    nbins : Union[int, Tuple[int, int]] = 20,
    bin_strategy : str = 'equisize',
    bias : str = None,
    norm : str = None,
    keys : List[str] = None,
    **kwargs
    ):
    '''
        Compute, for every layer and separately for each color channel,
        entropies and mutual information between an external quantity and
        the recorded unit activations.

        Args:
            data : mapping layer -> array whose last axis holds the R, G, B
                channels (indices 0, 1, 2), or None to skip the layer.
            acts : mapping layer -> activation array, or None to skip the layer.
            nbins : number of bins for (data, acts). A single int (or None) is
                used for both. A falsy entry means the corresponding input is
                assumed to be already binned and is used as-is.
            bin_strategy : binning strategy, one of ('equisize', 'equipop').
            bias : forwarded to the H / H_XY / I_XY estimators.
            norm : forwarded to I_XY.
            keys : layers to process; defaults to all keys of `data`.
            **kwargs : accepted for compatibility, currently unused.

        Returns:
            Tuple (H_r, H_g, H_b) of dicts mapping each key in `keys` to None
            (when data or activations are missing for that layer) or to a dict
            with entries 'Hx', 'Hy', 'Hx|y', 'Hy|x', 'Hx,y', 'Ix,y'. Each entry
            is a list with one value per element along axis 0 of the binned
            arrays (presumably the unit axis — confirm with callers).

        Raises:
            ValueError: if binning is requested and `bin_strategy` is unknown.
    '''
    # A single value (including None = already binned) applies to both inputs
    if nbins is None or isinstance(nbins, (int, np.integer)):
        nbins = (nbins, nbins)

    valid_bs = ('equisize', 'equipop')

    def pickbin(bins):
        # Return a function that bins one row for the requested strategy
        if bin_strategy == 'equisize':
            return partial(bin_data_equisize, nbins = bins)
        elif bin_strategy == 'equipop':
            return partial(bin_data_equipop,  nbins = bins)
        else:
            raise ValueError(f'Unknown binning strategy {bin_strategy}. Use one in {valid_bs}')

    def binall(arrays, bins, channel = None):
        # Bin every array row-wise (along axis 1), optionally after selecting a
        # channel on the last axis. A falsy `bins` means the input is already
        # binned; None entries (missing layers) are propagated unchanged.
        bin_f = pickbin(bins) if bins else None
        out = {}
        for layer, X in arrays.items():
            if X is None:
                out[layer] = None
                continue
            if channel is not None:
                X = X[..., channel]
            out[layer] = X if bin_f is None else np.apply_along_axis(bin_f, 1, X)
        return out

    # * Bin the provided external quantity, one color channel (R, G, B) at a time
    binR, binG, binB = (binall(data, nbins[0], channel = c) for c in range(3))

    # * Bin the recorded unit activations
    binF = binall(acts, nbins[1])

    # * Compute the relevant Information-Theoretic quantities
    # * for each channel separately and for each layer
    __H   = partial(H,    bias = bias)
    __Hxy = partial(H_XY, bias = bias, kind = 'x|y')
    __Hyx = partial(H_XY, bias = bias, kind = 'y|x')
    __HXY = partial(H_XY, bias = bias, kind = 'x,y')
    __Ixy = partial(I_XY, bias = bias, norm = norm, order = 'x,y')

    def _infoq(
            Xs : np.ndarray,
            Ys : np.ndarray,
            P : Pool,
            chunk : int = 10
        ):
        # Missing data or activations for this layer: nothing to compute
        if Xs is None or Ys is None: return None

        # Apply each estimator along axis 0, pairing rows of Xs and Ys
        return {
            'Hx'   : list(P.imap(__H, Xs, chunk)),
            'Hy'   : list(P.imap(__H, Ys, chunk)),
            'Hx|y' : list(P.starmap(__Hxy, zip(Xs, Ys), chunk)),
            'Hy|x' : list(P.starmap(__Hyx, zip(Xs, Ys), chunk)),
            'Hx,y' : list(P.starmap(__HXY, zip(Xs, Ys), chunk)),
            'Ix,y' : list(P.starmap(__Ixy, zip(Xs, Ys), chunk)),
        }

    keys = list(data.keys()) if keys is None else keys

    with Pool() as P:
        H_r = {l : _infoq(binR[l], binF[l], P) for l in tqdm(keys, desc = 'H [R]', leave = False)}
        H_g = {l : _infoq(binG[l], binF[l], P) for l in tqdm(keys, desc = 'H [G]', leave = False)}
        H_b = {l : _infoq(binB[l], binF[l], P) for l in tqdm(keys, desc = 'H [B]', leave = False)}

    return H_r, H_g, H_b

def filtbest(
    feats : dict,
    actvs : dict,
    score : dict,
    nunit : int,
    nimgs : int
    ) -> Tuple[dict, dict]:
    '''
        Filter features and activations, keeping only the best-scoring units
        and, for each kept unit, its best-scoring images.

        Units are ranked by their score averaged over images and channels.
        Images are ranked per unit by their score averaged over channels.
        The top `nunit` units and, for each of them, the top `nimgs` images
        are retained, ordered from lowest to highest score (np.argsort order).
        Asking for more units/images than available keeps all of them, and
        asking for zero yields empty arrays.

        Args:
            feats : mapping layer -> array indexed as (unit, image, ...).
            actvs : mapping layer -> array indexed as (unit, image, ...).
            score : mapping layer -> array indexed as (unit, image, channel).
            nunit : number of top-ranking units to retain.
            nimgs : number of top-ranking images to retain per unit.

        Returns:
            (f_feats, f_actvs): dicts keyed like `score`. Each array has shape
            (min(nunit, units), min(nimgs, images), ...).
    '''

    def _select(arr, units, imgs):
        # Broadcast the per-unit image indices over any trailing axes of `arr`
        idx = imgs.reshape(imgs.shape + (1,) * (arr.ndim - 2))
        return np.take_along_axis(arr[units], idx, axis = 1)

    f_feats, f_actvs = {}, {}
    for l, s in score.items():
        # Rank units by their average score across images and channels (ascending)
        usort = np.argsort(s.mean((-2, -1)))
        # Explicit start index: a slice like [-0:] would select everything
        units = usort[max(len(usort) - nunit, 0):]

        # Rank each selected unit's images by their channel-averaged score (ascending)
        isort = np.argsort(s.mean(-1), axis = 1)[units]
        imgs = isort[:, max(isort.shape[1] - nimgs, 0):]

        f_feats[l] = _select(feats[l], units, imgs)
        f_actvs[l] = _select(actvs[l], units, imgs)

    return f_feats, f_actvs
    
def buildfilt(
    scores      : dict,
    min_score   : float = -1e10,
    max_score   : float = +1e10,
    thr         : int   = 0,
    kind        : str   = 'rand',
    skip_layers : int   = 5,
    ) -> Tuple[dict, np.ndarray, np.ndarray]:
    '''
        Build per-layer boolean masks that select, for every retained unit,
        a fixed number of images whose channel-averaged score lies strictly
        within (min_score, max_score).

        In each layer, units are ranked by how many images pass the score
        filter. The `thr` units with the fewest passing images are discarded.
        The number of images kept per unit (`img_num`) is the minimum passing
        count among the retained units of every layer after the first
        `skip_layers` ones, so all layers use the same number of images.
        Each retained unit then draws `img_num` of its passing images
        uniformly without replacement, using the global NumPy RNG. Units
        with no passing image get an all-False mask.

        Args:
            scores : mapping layer -> array indexed as (unit, image, channel).
                All layers are assumed to share the same (unit, image) shape.
            min_score, max_score : exclusive bounds on the channel-averaged score.
            thr : number of units to discard per layer (those with fewest passing images).
            kind : currently unused.
            skip_layers : number of leading layers ignored when computing
                `img_num`. If this would leave no layer, all layers are used.

        Returns:
            masks : dict with keys 'min_score', 'img_num', 'pop_num' plus one
                boolean array per layer of shape (units - thr, images). Rows
                follow the `psort` order (ascending passing count), not the
                original unit order.
            psort : (layers, units) unit indices sorted by passing-image count.
            pnum : (layers, units) number of passing images per unit.

        Raises:
            ValueError: if a retained unit has some, but fewer than `img_num`,
                passing images (possible for units in the skipped layers).
    '''
    layers = list(scores.keys())

    # ? NOTE: the filter is applied to the channel-averaged score
    mean_score = {l : S.mean(axis = -1) for l, S in scores.items()}

    pop_size, img_size = mean_score[layers[0]].shape

    # Global condition mask: which (unit, image) pairs pass both bounds?
    img_cond = {l : np.logical_and(s > min_score, s < max_score) for l, s in mean_score.items()}

    # Number of passing images per unit, stacked as (layers, units)
    pnum = np.array([cond.sum(axis = -1) for cond in img_cond.values()])
    # Units of each layer sorted by increasing number of passing images
    psort = np.argsort(pnum, axis = 1)

    # Discard the thr units with the fewest passing images in every layer
    img_cond = {l : cond[order[thr:]] for (l, cond), order in zip(img_cond.items(), psort)}

    # * Equalize the image number across layers for a fair comparison: take the
    # * smallest passing count among retained units, ignoring the first
    # * `skip_layers` layers (falling back to all layers if none would remain)
    kept_pnum = np.take_along_axis(pnum, psort[:, thr:], 1)
    ref_pnum = kept_pnum[skip_layers:] if len(kept_pnum) > skip_layers else kept_pnum
    img_num = ref_pnum.min()
    pop_num = pop_size - thr

    def subfilt(img_mask, sub_num):
        # Draw sub_num of the passing images uniformly without replacement
        if not img_mask.any(): return img_mask.copy()

        mask = np.zeros(img_size, dtype = bool)
        try:
            sample = np.random.choice(img_size, size = sub_num, replace = False, p = img_mask / img_mask.sum())
        except ValueError as err:
            raise ValueError(
                f'Cannot draw {sub_num} images for a unit with only {img_mask.sum()} passing images'
            ) from err
        mask[sample] = True

        return mask

    img_filt = partial(subfilt, sub_num = img_num)

    masks = {'min_score' : min_score, 'img_num' : img_num, 'pop_num' : pop_num}
    for layer in tqdm(scores, desc = 'Filtering', leave = False):
        # Layer-specific mask: img_num passing images selected per retained unit
        masks[layer] = np.apply_along_axis(img_filt, 1, img_cond[layer])

    # Return the computed masks, the unit sorting and the passing counts
    return masks, psort, pnum