import torch
from torch.autograd import Function


class ReswishBinarizeF(Function):
    r"""Scaled-sign binarization with a SignSwish surrogate gradient.

    This :class:`torch.autograd.Function` implements the ReSwish / SignSwish
    binary approximation described in BNN+ (sec. 3.2).

    Forward:
        Non-zero entries become ``sign(x) * mean(|x|)``, i.e. ``+mean(|x|)``
        or ``-mean(|x|)``. The mean is taken over *all* elements of
        ``inputs``. Entries that are exactly zero are mapped to ``1``, which
        is *not* multiplied by the mean. The output has the same dtype as
        ``inputs``.

        NOTE(review): the unscaled ``1`` for zeros looks inconsistent with
        the scaled non-zero entries. Confirm whether zeros should map to
        ``mean(|x|)`` instead.

    Backward:
        The derivative of ``SignSwish_beta``:

        .. math::
            \frac{\beta\,(2 - \beta x \tanh(\beta x / 2))}{1 + \cosh(\beta x)}

        This is twice the second derivative of Swish, ``x * sigmoid(beta * x)``.
        It is multiplied elementwise by ``grad_output``. The dependence of the
        forward scale ``mean(|x|)`` on ``inputs`` is ignored, in the manner of
        a straight-through estimator. ``beta`` receives no gradient.

    Args:
        inputs: floating-point tensor to binarize.
        beta: sharpness of the backward approximation. It is converted with
            ``float`` and stored on ``ctx``.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, beta: float) -> torch.Tensor:
        ctx.beta = float(beta)
        ctx.save_for_backward(inputs)
        # sign(0) == 0, so exact zeros are pushed to +1 by a separate mask.
        # Cast the mask to the input dtype: a float32 mask would promote
        # half/bfloat16 inputs to float32 under PyTorch type promotion.
        zero_fill = (inputs == 0).to(inputs.dtype)
        return inputs.sign() * inputs.abs().mean() + zero_fill

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        beta = ctx.beta
        (inputs,) = ctx.saved_tensors
        scaled_inputs = beta * inputs
        # d/dx SignSwish_beta(x). For very large |beta * x|, cosh overflows to
        # inf while the numerator stays finite, so the result is 0. That
        # matches the limit of the true derivative.
        grad_inputs = (
            beta * (2 - scaled_inputs * torch.tanh(scaled_inputs / 2))
        ) / (1 + torch.cosh(scaled_inputs))
        # beta is a plain Python float hyperparameter, so it gets no gradient.
        return grad_inputs * grad_output, None


# Functional entry point: reswish_binarize(inputs, beta) runs
# ReswishBinarizeF.forward and registers its custom backward with autograd.
reswish_binarize = ReswishBinarizeF.apply
