import numpy as np
import torch
from torch.nn.functional import softplus

def asinh(x):
    """
    Element-wise inverse hyperbolic sine, ``log(x + sqrt(x**2 + 1))``.

    The textbook formula cancels catastrophically for large negative ``x``
    (``x + sqrt(x**2 + 1)`` rounds to 0, giving ``-inf``). This version instead
    evaluates it on ``|x|`` and restores the sign, using the odd symmetry
    ``asinh(-x) == -asinh(x)``.

    :param x: input tensor
    :return: tensor of the same shape with ``asinh(x)``
    """
    negative = x < 0
    # torch.where (not torch.abs/sign) so the gradient at x == 0 is 1, and both
    # branches stay finite, so no NaN leaks into the backward pass.
    magnitude = torch.where(negative, -x, x)
    result = torch.log(magnitude + torch.sqrt(magnitude ** 2 + 1))
    return torch.where(negative, -result, result)


def acosh(x):
    """
    Element-wise inverse hyperbolic cosine, ``log(x + sqrt(x**2 - 1))``.

    ``sqrt(x**2 - 1)`` is computed as ``sqrt(x - 1) * sqrt(x + 1)``. This avoids
    overflowing ``x**2`` for large ``x``, which would make the result ``inf``.
    It also keeps more precision for ``x`` close to 1.

    :param x: input tensor; the result is finite for ``1 <= x < inf`` and nan
        for ``x < 1``
    :return: tensor of the same shape with ``acosh(x)``
    """
    return torch.log(x + torch.sqrt(x - 1) * torch.sqrt(x + 1))


def atanh(x):
    """
    Element-wise inverse hyperbolic tangent, ``0.5 * log((1 + x) / (1 - x))``.

    Written as ``0.5 * (log1p(x) - log1p(-x))`` so small ``|x|`` keeps full
    precision. The ratio form rounds ``(1 + x) / (1 - x)`` to 1 and returns 0.

    :param x: input tensor; finite for ``-1 < x < 1``, ``+/-inf`` at ``+/-1``,
        nan for ``|x| > 1``
    :return: tensor of the same shape with ``atanh(x)``
    """
    return 0.5 * (torch.log1p(x) - torch.log1p(-x))


def logit(x):
    """
    Log-odds ``log(x / (1 - x))``, the inverse of the logistic sigmoid
    (``nn.functional.sigmoid`` / ``torch.sigmoid``).

    :param x: tensor of probabilities; finite output for ``0 < x < 1``
        (0 maps to -inf, 1 maps to inf, values outside [0, 1] give nan)
    :return: tensor of log-odds with the same shape as ``x``
    """
    odds = x / (1 - x)
    return torch.log(odds)


def sigmoid_inv(x):
    """
    Inverse of the logistic sigmoid; an alias that delegates to ``logit``.

    :param x: tensor of probabilities
    :return: tensor of log-odds, exactly as returned by ``logit(x)``
    """
    log_odds = logit(x)
    return log_odds


def softplus2(x, exponent=2):
    """
    Softplus raised to a power: ``softplus(x) ** exponent``.

    :param x: input tensor
    :param exponent: power applied to the softplus output (default 2)
    :return: tensor ``softplus(x) ** exponent``
    """
    base = softplus(x)
    return torch.pow(base, exponent)


def softplus2_inv(x, exponent=2):
    """
    Inverse of ``softplus2``: returns ``y`` with ``softplus(y) ** exponent == x``.

    First takes the ``exponent``-th root of ``x``, then inverts softplus via
    ``log(exp(root) - 1)``. The inversion is computed in the equivalent stable
    form ``root + log(-expm1(-root))``. The naive form overflows to ``inf``
    once ``exp(root)`` overflows (root > ~88 in float32) and loses precision
    for small ``root``.

    :param x: tensor of softplus2 outputs; finite result only for ``x > 0``
        (0 gives -inf, negative values give nan)
    :param exponent: the power used in the forward ``softplus2`` (default 2)
    :return: tensor ``y`` such that ``softplus2(y, exponent) == x`` up to rounding
    """
    root = torch.pow(x, 1 / exponent)
    return root + torch.log(-torch.expm1(-root))


def softplus_inv(x):
    """
    Inverse of ``softplus`` (default ``beta=1``): ``log(exp(x) - 1)``.

    Computed in the equivalent form ``x + log(-expm1(-x))``. The naive form
    returns ``inf`` once ``exp(x)`` overflows (x > ~88 in float32) and loses
    precision for small positive ``x``.

    :param x: tensor of softplus outputs; finite result only for ``x > 0``
        (0 gives -inf, negative values give nan, as with the naive formula)
    :return: tensor ``y`` such that ``softplus(y) == x`` up to rounding
    """
    return x + torch.log(-torch.expm1(-x))


def diff(a, axis=0):
    """
    First-order discrete difference along ``axis`` (``np.diff`` with ``n=1``).

    Works for inputs of any rank and for negative ``axis``, like ``np.diff``.
    Note that the default ``axis=0`` differs from ``np.diff``, whose default is
    the last axis.
    Slicing approach from: https://stackoverflow.com/a/42612608

    :param a: input tensor (or anything with ``.shape`` that supports
        tuple-of-slices indexing and subtraction, e.g. a numpy array)
    :param axis: axis along which to difference; negative values count from
        the end
    :return: ``a[1:] - a[:-1]`` taken along ``axis``; that axis shrinks from
        size ``k`` to ``max(k - 1, 0)``
    :raises IndexError: if ``axis`` is out of range for ``a``
    """
    ndim = len(a.shape)
    if not -ndim <= axis < ndim:
        raise IndexError(
            "axis {} is out of range for input with {} dimension(s)".format(axis, ndim))
    # Full slices everywhere except `axis`, where we take [1:] and [:-1].
    upper = [slice(None)] * ndim
    lower = [slice(None)] * ndim
    upper[axis] = slice(1, None)
    lower[axis] = slice(None, -1)
    return a[tuple(upper)] - a[tuple(lower)]


def mnlp(actual_mean, pred_mean, pred_var, reduction="sum"):
    """
    Gaussian negative log probability of ``actual_mean`` under
    ``N(pred_mean, pred_var)``.

    Per element: ``0.5 * (log(pred_var) + log(2*pi) + (actual_mean - pred_mean)**2 / pred_var)``.

    Despite the name ("Mean Negative Log Probability"), the default reduction
    is a *sum* over all elements. This is kept for backward compatibility; pass
    ``reduction="mean"`` for the per-element average.

    :param actual_mean: observed targets (tensor)
    :param pred_mean: predicted means, broadcastable against ``actual_mean``
    :param pred_var: predicted variances (must be > 0), broadcastable against
        the above
    :param reduction: ``"sum"`` (default), ``"mean"``, or ``"none"`` (return
        the per-element tensor)
    :return: the reduced negative log probability as a tensor
    :raises ValueError: if ``reduction`` is not one of the values above
    """
    # A Python-float constant takes on pred_var's dtype/device. Previously it
    # came from a fresh float32 CPU tensor, which lost precision for float64.
    log_part = torch.log(pred_var) + float(np.log(2.0 * np.pi))
    unc_part = ((actual_mean - pred_mean) / torch.sqrt(pred_var)) ** 2
    summed_parts = 0.5 * (log_part + unc_part)
    if reduction == "sum":
        return torch.sum(summed_parts)
    if reduction == "mean":
        return torch.mean(summed_parts)
    if reduction == "none":
        return summed_parts
    raise ValueError(
        "reduction must be 'sum', 'mean' or 'none', got {!r}".format(reduction))


def nlml_chol_fast(y_trn,
                   r,
                   rb_solve,
                   n,
                   m,
                   alpha,
                   beta):
    """
    Negative log marginal likelihood from precomputed Cholesky quantities.

    Evaluates::

        nlml = -( -beta/2 * (||y_trn||_F^2 - ||rb_solve||_F^2)
                  - sum(log|diag(r)|)
                  + m/2 * log(alpha / beta)
                  - n/2 * log(2*pi / beta) )

    NOTE(review): the formula matches the Bayesian linear-regression evidence
    with ``r`` a Cholesky factor of ``Phi^T Phi + (alpha/beta) I`` and
    ``rb_solve`` a triangular solve of that factor against ``Phi^T y``. This is
    inferred from the formula only; confirm against the caller.

    :param y_trn: training targets (tensor)
    :param r: triangular (presumably Cholesky) factor; only its diagonal is read
    :param rb_solve: tensor whose squared Frobenius norm is subtracted from
        that of ``y_trn``
    :param n: presumably the number of training points
    :param m: presumably the number of features / basis functions
    :param alpha: presumably the weight-prior precision; ``alpha / beta`` must
        be a tensor, since ``torch.log`` is applied to it
    :param beta: presumably the noise precision; ``2*pi / beta`` must be a tensor
    :return: the negative log marginal likelihood (tensor)
    """
    e1_chol = torch.norm(y_trn, p="fro") ** 2
    e2_chol = torch.norm(rb_solve, p="fro") ** 2
    e_chol = (- beta / 2.0) * (e1_chol - e2_chol)

    # sum(log|d|) equals 0.5 * sum(log(d**2)) but does not square first. Squaring
    # underflows to log(0) = -inf for tiny entries or overflows for huge ones.
    logdet_chol = torch.sum(torch.log(torch.abs(torch.diagonal(r))))
    p1_chol = (m / 2.0) * torch.log(alpha / beta)
    p2_chol = (n / 2.0) * torch.log((2.0 * np.pi) / beta)
    nlml_chol = -1.0 * (e_chol - logdet_chol + p1_chol - p2_chol)
    return nlml_chol
