from torch.autograd import Function

class STE(Function):
    """Scaled sign binarization with a straight-through gradient.

    Forward maps every element of ``inputs`` to ``+alpha`` or ``-alpha``
    (exact zeros go to ``+alpha``), where ``alpha = mean(|inputs|)``, and
    passes that tensor through ``prox_op``. Backward ignores the whole forward
    computation, including ``prox_op``, and hands the incoming gradient to
    ``inputs`` unchanged (straight-through estimator).

    Call it as ``STE.apply(inputs, prox_op)``.
    """

    @staticmethod
    def forward(ctx, inputs, prox_op):
        """Binarize ``inputs`` to ``±mean(|inputs|)`` and apply ``prox_op``.

        Args:
            ctx: Autograd context (unused; nothing is saved for backward).
            inputs: Floating-point tensor to binarize.
            prox_op: Callable applied to the scaled binary tensor. Its result
                is returned as-is. NOTE(review): the name suggests a proximal
                operator; confirm the expected contract with callers.

        Returns:
            Whatever ``prox_op`` returns for the binarized tensor.
        """
        # sign() maps 0 to 0. Adding 1 at exact zeros makes every element
        # +1 or -1. The mask is cast to inputs' dtype so half/bfloat16
        # inputs are not promoted to float32 by the addition.
        signs = inputs.sign() + (inputs == 0).to(inputs.dtype)
        # Scale the complete ±1 tensor so zeros become +alpha, consistent
        # with the positive entries, rather than an unscaled 1.
        alpha = inputs.abs().mean()
        return prox_op(signs * alpha)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass ``grad_output`` straight through as the gradient of ``inputs``.

        Returns:
            ``(grad_output, None)``: one entry per ``forward`` argument.
            ``prox_op`` is a callable rather than a tensor, so it gets ``None``.
        """
        return grad_output, None
