import torch
import torch.nn as nn

def lecun_tanh(input):
    """Apply LeCun's scaled hyperbolic tangent elementwise.

    Computes ``1.7159 * tanh(2 * input / 3)``.

    Args:
        input: Tensor to transform.

    Returns:
        A tensor of the same shape as ``input``.
    """
    return 1.7159 * torch.tanh(2 * input / 3)

class LeCun_tanh(nn.Module):
    """Module form of :func:`lecun_tanh`, i.e. ``1.7159 * tanh(2 * x / 3)``.

    Holds no parameters or buffers; the inherited ``nn.Module.__init__`` is
    sufficient, so no constructor is defined here.
    """

    def forward(self, x):
        # Delegate to the functional implementation defined in this module.
        return lecun_tanh(x)

class Elliot_auto(torch.autograd.Function):
    """Elliot (softsign) activation ``x / (1 + |x|)`` with a hand-written backward.

    The derivative ``1 / (1 + |x|)**2`` is continuous. At ``x == 0`` it equals 1,
    because the left and right derivatives agree, so no special case is needed.
    """

    @staticmethod
    def forward(ctx, input):
        # Only the input is needed to evaluate the derivative in backward.
        ctx.save_for_backward(input)
        return input / (1 + torch.abs(input))

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            # d/dx [x / (1 + |x|)] = 1 / (1 + |x|)**2 for every x, including 0
            # (where it is 1). Do not zero the gradient at x == 0.
            grad_input = grad_output / (1 + torch.abs(input)) ** 2
        return grad_input

class Elliot(nn.Module):
    """Stateless module applying the Elliot activation ``x / (1 + |x|)``.

    Forward and backward are provided by the custom ``Elliot_auto`` autograd
    function. The inherited ``nn.Module.__init__`` is enough, since there are no
    parameters or buffers.
    """

    def forward(self, input):
        # Route the tensor through the custom autograd Function.
        return Elliot_auto.apply(input)
    
class ShiftedTanh(nn.Module):
    """Shifted hyperbolic tangent ``tanh(x + tau) - tanh(tau)``.

    Subtracting ``tanh(tau)`` moves the curve so that the output is 0 at ``x = 0``.

    Args:
        tau: Shift added to the input before ``tanh``; a Python number or a tensor.

    Buffers:
        tau: The shift, stored as a floating-point (or complex) tensor so that
            module dtype conversions such as ``.double()`` / ``.half()`` apply to it.
        sparse_value: ``tanh(tau)``, the constant subtracted in ``forward``.
    """

    def __init__(self, tau):
        super().__init__()
        self.tanh = nn.Tanh()
        # as_tensor(...).detach().clone() accepts numbers or tensors without the
        # copy-construct warning torch.tensor() emits for tensor inputs, and never
        # aliases a tensor owned by the caller.
        tau = torch.as_tensor(tau).detach().clone()
        if not (tau.is_floating_point() or tau.is_complex()):
            # Integer/bool buffers are skipped by module dtype casts; use the
            # default float dtype (the dtype tanh(int) already produced).
            tau = tau.to(torch.get_default_dtype())
        self.register_buffer('tau', tau)
        self.register_buffer('sparse_value', torch.tanh(tau))

    def forward(self, input):
        return self.tanh(input + self.tau) - self.sparse_value