import torch
from unnq.quantization.linear import ProxGradLinear, DualAveragingLinear
from unnq.quantization.conv2d import ProxGradConv2d, DualAveragingConv2d


def flatten_quant_weights_and_grads(method, model, prox_op):
    """ Collects flattened weight and gradient tensors of the quantization
    layers in `model` that use `prox_op`.

    Args:
        method: string; 'da' selects the Dual Averaging collector, any other
                value (normally 'pg', Proximal Gradient) selects the
                Proximal Gradient collector.
        model: unnq.quantization.quant.Quantization, the quantization model
        prox_op: unnq.prox.x, the proximal operator whose layers' weights
                 and gradients should be gathered

    Returns:
        (weight_list, grad_list) as produced by the selected collector.
    """

    collector = (flatten_quant_weights_and_grads_da if method == 'da'
                 else flatten_quant_weights_and_grads_pg)
    return collector(model, prox_op)


def flatten_quant_weights_and_grads_da(model, prox_op):
    """ Returns flattened projected weights and gradients of the Dual Averaging
    layers in `model` that use `prox_op`.

    Args:
        model: module whose submodules are scanned via `model.modules()`.
        prox_op: proximal operator; only DualAveragingLinear /
                 DualAveragingConv2d layers whose `prox_op` is this exact
                 object are included.

    Returns:
        (weight_list, grad_list): for each matching layer, in module order,
        `prox_op(m.weight)` flattened to 1d and `m.weight.grad` flattened
        to 1d. Every matching layer must have a populated `.grad` (not None).
    """
    weight_list = []
    grad_list = []

    da_layer_types = (DualAveragingLinear, DualAveragingConv2d)
    for m in model.modules():
        # The type test must be grouped before `and`: the unparenthesized
        # `A or B and C` parses as `A or (B and C)`, which included every
        # DualAveragingLinear regardless of which prox_op it used.
        if isinstance(m, da_layer_types) and m.prox_op is prox_op:
            weight_list.append(prox_op(m.weight).reshape(-1))
            grad_list.append(m.weight.grad.reshape(-1))

    return weight_list, grad_list


def flatten_quant_weights_and_grads_pg(model, prox_op):
    """ Returns flattened weights and gradients of the Proximal Gradient
    layers in `model` that use `prox_op`.

    Args:
        model: module whose submodules are scanned via `model.modules()`.
        prox_op: proximal operator; only ProxGradLinear / ProxGradConv2d
                 layers whose `prox_op` is this exact object are included.

    Returns:
        (weight_list, grad_list): for each matching layer, in module order,
        `m.weight` flattened to 1d (no projection applied) and
        `m.weight.grad` flattened to 1d. Every matching layer must have a
        populated `.grad` (not None).
    """
    weight_list = []
    grad_list = []

    pg_layer_types = (ProxGradLinear, ProxGradConv2d)
    for m in model.modules():
        # The type test must be grouped before `and`: the unparenthesized
        # `A or B and C` parses as `A or (B and C)`, which included every
        # ProxGradLinear regardless of which prox_op it used.
        if isinstance(m, pg_layer_types) and m.prox_op is prox_op:
            weight_list.append(m.weight.reshape(-1))
            grad_list.append(m.weight.grad.reshape(-1))

    return weight_list, grad_list


def compute_midpoints(w):
    """ Given a 1d tensor of size n, computes all n - 1 midpoints.

    For example, returns torch.tensor([1.0, 2.1]) for input
    w = torch.tensor([0.0, 2.0, 2.2]).

    Args:
        w: 1d tensor (or anything `torch.as_tensor` accepts) of n values.

    Returns:
        A 1d tensor of size max(n - 1, 0) whose i-th entry is
        (w[i] + w[i + 1]) / 2, computed without autograd tracking. It stays
        on w's device; floating inputs keep their dtype, integer inputs are
        promoted to the default float dtype by true division.

    Raises:
        ValueError: if w is not exactly 1-dimensional.
    """

    w = torch.as_tensor(w)
    # Checking dim() != 1 also rejects 0-d tensors, which previously failed
    # later with an IndexError on w.shape[0].
    if w.dim() != 1:
        raise ValueError('Input can only be 1d tensor.')

    with torch.no_grad():
        # Vectorized neighbour average; replaces a per-element Python loop
        # and no longer forces a float32 CPU result. Empty or single-element
        # input yields an empty tensor.
        return (w[:-1] + w[1:]) / 2.0

def compute_weight_dist(method, quant):
    """ Computes the weight distribution; which weights are quantized where after projection.

    Args:
        method: string, 'pg' or 'da'. Not used by this function; kept for
                interface compatibility with callers.
        quant: quantization object exposing `prox_ops` (iterable of proximal
               operators) and `model` (module whose quantization layers carry
               a `prox_op` attribute).

    Returns:
        A list with one entry per operator in `quant.prox_ops`, in the same
        order. Each entry is the `(values, counts)` pair returned by
        `torch.unique(..., return_counts=True)` over the projected weights
        `p(m.weight)` of every quantization layer whose `prox_op == p`.
        An operator that no layer uses yields a pair of empty tensors.
    """

    quant_layer_types = (ProxGradLinear,
                         DualAveragingLinear,
                         ProxGradConv2d,
                         DualAveragingConv2d)

    weight_list = []
    # Statistics only: avoid building an autograd graph for the projections.
    with torch.no_grad():
        for p in quant.prox_ops:
            weight_list_p = [
                p(m.weight).reshape(-1, 1)
                for m in quant.model.modules()
                if isinstance(m, quant_layer_types) and p == m.prox_op
            ]
            if weight_list_p:
                flat = torch.cat(weight_list_p, dim=0)
            else:
                # torch.cat raises on an empty list; report an empty
                # distribution so the output stays aligned with prox_ops.
                flat = torch.empty(0, 1)
            weight_list.append(torch.unique(flat, return_counts=True))
    return weight_list

def compute_quant_dist(method, quant, threshold=1e-6):
    """ Computes the quantization distribution; which weights are actually quantized.

    Args:
        method: string, 'pg' or 'da'. Not used by this function; kept for
                interface compatibility with callers.
        quant: quantization object exposing `prox_ops` (each with a
               `_weights` collection of quantization levels) and `model`
               (module whose quantization layers carry a `prox_op` attribute).
        threshold: a raw weight counts as sitting on a level when its
                   absolute distance to that level is below this value.

    Returns:
        A list with one `(n_total, counts)` tuple per operator in
        `quant.prox_ops`, in the same order. `n_total` is the number of
        weights in the layers whose `prox_op == p`. `counts` maps each
        level (as a Python number, via `.item()`) to the number of those
        weights within `threshold` of it. Each count is the accumulated
        `torch.sum` result, or int 0 when no layer uses the operator.
    """

    quant_layer_types = (ProxGradLinear,
                         DualAveragingLinear,
                         ProxGradConv2d,
                         DualAveragingConv2d)

    quant_list = []
    # Statistics only: no autograd graph needed for the comparisons.
    with torch.no_grad():
        for p in quant.prox_ops:
            quant_list_p = {w.cpu().item(): 0 for w in p._weights}
            n_total = 0
            for m in quant.model.modules():
                if not (isinstance(m, quant_layer_types) and p == m.prox_op):
                    continue
                # Flatten and count once per layer. The old code did this
                # once per level and then divided by len(p._weights), which
                # raised ZeroDivisionError for an operator with no levels.
                m_weight = m.weight.reshape(-1, 1)
                n_total += m_weight.shape[0]
                for w in p._weights:
                    bool_tensor = torch.abs(m_weight - w) < threshold
                    quant_list_p[w.cpu().item()] += torch.sum(bool_tensor)

            quant_list.append((n_total, quant_list_p))
    return quant_list
