import torch
import torch.nn.functional as F
import torch.nn as nn
import dgl.function as fn



def correct(
    g,
    y_soft,
    y_true,
    mask,
    num_correction_layers=500,
    correction_alpha=0.01,
    correction_adj="DAD",
    autoscale=True,
    scale=1.0,
):
    """Propagate the residual error of ``y_soft`` on the masked nodes over ``g``.

    This is the "correct" step of Correct & Smooth, adapted to return the
    propagated error rather than corrected predictions. Sign convention: the
    error on masked rows is ``y_soft[mask] - y_true`` (prediction minus
    target). A caller that wants to move predictions toward the targets would
    presumably subtract the returned error — confirm against the call site.

    Parameters
    ----------
    g :
        Graph handed to ``LabelPropagation``. Presumably a DGL graph, given
        the ``local_scope``/``update_all`` usage.
    y_soft :
        Float tensor of per-node predictions, rows indexed by ``mask``. The
        autoscale path sums over dim 1, so it is assumed to be 2-D
        (nodes, classes).
    y_true :
        Targets for the masked rows. Either class indices (``torch.long``,
        one-hot encoded to ``y_soft.size(-1)`` classes) or a tensor matching
        ``y_soft[mask]``.
    mask :
        Boolean/uint8 node mask, or an index tensor of node ids.
    num_correction_layers, correction_alpha, correction_adj :
        Forwarded to ``LabelPropagation`` as ``num_layers``, ``alpha`` and
        ``adj``.
    autoscale :
        If True, the error is propagated with values clamped to [-1, 1]. Each
        row of the result is then rescaled so its L1 norm equals the mean L1
        error over the masked rows (``scale`` is ignored). If False, the masked
        rows are re-pinned to their known error after every step, and the
        result is multiplied by the constant ``scale``.
    scale :
        Constant multiplier used only when ``autoscale`` is False.

    Returns
    -------
    tuple
        ``(smoothed_error, scaled_error)``.

    Raises
    ------
    AssertionError
        If ``y_true`` does not have exactly one row per masked node.
    """
    with g.local_scope():
        if mask.dtype == torch.uint8:
            # Indexing with uint8 masks is deprecated in PyTorch; bool is equivalent.
            mask = mask.bool()
        numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)
        assert y_true.size(0) == numel

        prop = LabelPropagation(
            num_correction_layers, correction_alpha, correction_adj
        )

        if y_true.dtype == torch.long:
            y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to(
                y_soft.dtype
            )

        error = torch.zeros_like(y_soft)
        error[mask] = y_soft[mask] - y_true

        if autoscale:
            smoothed_error = prop(
                g, error, post_step=lambda x: x.clamp_(-1.0, 1.0)
            )
            # Mean per-row L1 error over the masked rows.
            sigma = error[mask].abs().sum() / numel
            row_scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True)
            # Rows whose smoothed error is all zero give sigma/0 = inf, or
            # 0/0 = NaN when sigma == 0 or numel == 0. Reset every non-finite
            # (and implausibly large) factor to 1.0 so those rows keep their
            # smoothed error instead of turning into NaN.
            row_scale[~torch.isfinite(row_scale) | (row_scale > 1000)] = 1.0
            return smoothed_error, row_scale * smoothed_error

        def fix_input(x):
            # Keep the known errors on the masked nodes fixed after every step.
            x[mask] = error[mask]
            return x

        smoothed_error = prop(g, error, post_step=fix_input)
        return smoothed_error, scale * smoothed_error


class LabelPropagation(nn.Module):
    r"""Label propagation over a graph.

    Introduced in `Learning from Labeled and Unlabeled Data with Label
    Propagation <https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.3864&rep=rep1&type=pdf>`_.

    Let :math:`\mathbf{Y}^{(0)}` be the input labels, zeroed outside ``mask``
    when a mask is given. Each of the ``num_layers`` steps computes

    .. math::
        \mathbf{Y}^{(k+1)} = \mathrm{post\_step}\left(\alpha \cdot
        \tilde{\mathbf{A}} \mathbf{Y}^{(k)} + (1 - \alpha) \mathbf{Y}^{(0)}\right),

    where :math:`\tilde{\mathbf{A}}` is the adjacency matrix normalized as
    selected by ``adj``. The degree matrix :math:`D` uses in-degrees clamped
    to at least 1.

    Parameters
    ----------
    num_layers : int
        The number of propagation steps.
    alpha : float
        The :math:`\alpha` coefficient weighting propagated values against
        the initial labels.
    adj : str
        Adjacency normalization:

        * ``'DAD'``: :math:`D^{-1/2} A D^{-1/2}`
        * ``'DA'``: :math:`D^{-1} A`
        * ``'AD'``: :math:`A D^{-1}`

    Raises
    ------
    ValueError
        If ``adj`` is not one of the values above.
    """

    # Supported adjacency normalizations. Any other value previously fell
    # through silently to unnormalized propagation.
    _ADJ_MODES = ("DAD", "DA", "AD")

    def __init__(self, num_layers, alpha, adj="DAD"):
        super().__init__()
        if adj not in self._ADJ_MODES:
            raise ValueError(
                f"adj must be one of {self._ADJ_MODES}, got {adj!r}"
            )

        self.num_layers = num_layers
        self.alpha = alpha
        self.adj = adj

    @torch.no_grad()
    def forward(
        self, g, labels, mask=None, post_step=lambda y: y.clamp_(0.0, 1.0)
    ):
        """Run ``num_layers`` propagation steps and return the result.

        Parameters
        ----------
        g :
            Graph to propagate over. Presumably a DGL graph, given the
            ``ndata``/``update_all`` usage. It is assumed undirected: only
            in-degrees are used for normalization.
        labels :
            Node labels whose first dimension indexes nodes. A ``torch.long``
            tensor is treated as class indices and one-hot encoded to
            float32; the number of classes is inferred from the largest index
            present.
        mask :
            Optional node mask or indices. Labels outside it start at zero.
        post_step :
            Applied to the new tensor after every step. The default clamps
            values into [0, 1] in place.

        Returns
        -------
        Tensor
            Propagated values, with the same shape as the (one-hot encoded)
            labels assuming ``post_step`` preserves shape. With
            ``num_layers == 0`` the initial tensor is returned without
            calling ``post_step``.
        """
        with g.local_scope():
            if labels.dtype == torch.long:
                labels = F.one_hot(labels.view(-1)).to(torch.float32)

            y = labels
            if mask is not None:
                y = torch.zeros_like(labels)
                y[mask] = labels[mask]

            # Restart term: a fixed fraction of the *initial* labels.
            last = (1 - self.alpha) * y

            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5 if self.adj == "DAD" else -1).to(
                labels.device
            )
            # Shape (N, 1, ..., 1) so the factor scales node rows for labels of
            # any rank. A fixed unsqueeze(1) broadcast 1-D labels to (N, N).
            norm = norm.view(-1, *([1] * (y.dim() - 1)))

            for _ in range(self.num_layers):
                # Assume the graphs to be undirected
                if self.adj in ("DAD", "AD"):
                    y = norm * y

                g.ndata["h"] = y
                g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                y = self.alpha * g.ndata.pop("h")

                if self.adj in ("DAD", "DA"):
                    y = y * norm

                y = post_step(last + y)

            return y