import torch
from torch.autograd import Function


class ReswishBinarizeF2(Function):
    r"""SignSwish binarization whose backward pass is the exact SignSwish derivative.

    Implements a modified version of the ReSwish binary approximation
    described in BNN+ (Sec. 3.2).

    Forward pass (SignSwish):

    .. math::
        SS_\beta(x) = 2\sigma(\beta x)\,\bigl[1 + \beta x\,(1 - \sigma(\beta x))\bigr] - 1

    Backward pass, the derivative of SignSwish (equivalently, twice the
    second derivative of Swish :math:`x\,\sigma(\beta x)`):

    .. math::
        \frac{d\,SS_\beta}{dx} = \frac{\beta\,\bigl(2 - \beta x \tanh(\beta x / 2)\bigr)}{1 + \cosh(\beta x)}

    Args:
        inputs: Tensor to binarize.
        beta: Scalar slope parameter (Python number or single-element
            tensor). It is converted with ``float`` and the same value is used
            in both passes, so the output keeps the dtype and shape of
            ``inputs``. No gradient is returned for ``beta``.
    """

    @staticmethod
    def forward(ctx, inputs, beta):
        # Convert once so forward and backward use the identical scalar and a
        # tensor-valued beta cannot promote/broadcast the output away from
        # ``inputs``' dtype and shape.
        beta = float(beta)
        ctx.beta = beta
        ctx.save_for_backward(inputs)
        scaled_inputs = beta * inputs
        # The sigmoid appears twice in the formula; evaluate it only once.
        sig = torch.sigmoid(scaled_inputs)
        return 2 * sig * (1 + scaled_inputs * (1 - sig)) - 1

    @staticmethod
    def backward(ctx, grad_output):
        # Skip the math entirely when no gradient w.r.t. ``inputs`` is needed.
        if not ctx.needs_input_grad[0]:
            return None, None
        beta = ctx.beta
        inputs, = ctx.saved_tensors
        scaled_inputs = beta * inputs
        # For large finite |beta * x|, cosh overflows to inf while the
        # numerator stays finite, so the gradient correctly goes to 0.
        grad = (
            beta * (2 - scaled_inputs * torch.tanh(scaled_inputs / 2))
            / (1 + torch.cosh(scaled_inputs))
        )
        # ``beta`` is treated as a constant hyper-parameter: no gradient.
        return grad * grad_output, None


# Functional entry point: reswish_binarize2(inputs, beta) runs
# ReswishBinarizeF2's forward; in backward only ``inputs`` receives a
# gradient (``None`` is returned for ``beta``).
reswish_binarize2 = ReswishBinarizeF2.apply
