"""Poincare model utils functions."""

import torch

from utils.math_fns import arctanh, tanh

# Floor applied via `clamp_min` to norms and denominators throughout this
# module so that divisions never hit an exact zero.
MIN_NORM = 1e-15
# Per-dtype margin used by `project`: points are kept at norm at most
# (1 - eps) / sqrt(c). Only float32 and float64 are listed here.
BALL_EPS = {torch.float32: 4e-3, torch.float64: 1e-5}


# ################# HYP OPS ########################

def egrad2rgrad(p, dp, c=1.0):
    """Rescale the gradient `dp` taken at point `p` by 1 / lambda_x(p, c)^2.

    Note that `dp` is divided in place; the same (mutated) object is returned.

    p: point at which the gradient was taken
    dp: gradient to rescale (modified in place)
    c: curvature passed through to `lambda_x`
    return: `dp` after the in-place rescaling
    """
    conformal_sq = lambda_x(p, c).pow(2)
    dp /= conformal_sq
    return dp

def expmap0(u, c=1.0):
    """Exponential map at the origin, followed by a projection via `project`.

    u: tangent vectors, coordinates along the last dimension
    c: curvature
    return: tanh(sqrt(c) * ||u||) * u / (sqrt(c) * ||u||), passed through `project`
    """
    root_c = c ** 0.5
    # Clamp the norm so the division below is safe for zero vectors.
    norm_u = u.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM)
    scaled_norm = root_c * norm_u
    mapped = tanh(scaled_norm) * u / scaled_norm
    return project(mapped, c)


def lambda_x(x, c=1.0):
    """Return 2 / (1 - c * ||x||^2), with the denominator floored at MIN_NORM.

    The squared norm is computed from `x.data`, so the result is not connected
    to the autograd history of `x`.

    x: points, coordinates along the last dimension
    c: curvature
    return: tensor shaped like `x` with the last dimension reduced to size 1
    """
    sq_norm = x.data.pow(2).sum(dim=-1, keepdim=True)
    gap = (1. - c * sq_norm).clamp_min(MIN_NORM)
    return 2 / gap


def inner(x, u, v=None, c=1.0):
    """Inner product of `u` and `v` at point `x`: lambda_x(x, c)^2 * <u, v>.

    x: base point
    u: first vector
    v: second vector; when None, `u` is used for both arguments
    c: curvature passed through to `lambda_x`
    return: inner product with the last dimension kept at size 1
    """
    other = u if v is None else v
    conformal = lambda_x(x, c)
    euclid_dot = (u * other).sum(dim=-1, keepdim=True)
    return conformal ** 2 * euclid_dot


def gyration(u, v, w, c=1.0):
    """Apply the gyration generated by `u` and `v` to the vector `w`.

    All inner products are taken along the last dimension and kept at size 1
    so that they broadcast against the vectors.

    u, v: vectors defining the gyration
    w: vector being transformed
    c: curvature
    return: w + 2 * (a * u + b * v) / d, with d floored at MIN_NORM
    """
    def _dot(p, q):
        # Inner product along the last dimension, kept for broadcasting.
        return (p * q).sum(dim=-1, keepdim=True)

    sq_u = u.pow(2).sum(dim=-1, keepdim=True)
    sq_v = v.pow(2).sum(dim=-1, keepdim=True)
    dot_uv = _dot(u, v)
    dot_uw = _dot(u, w)
    dot_vw = _dot(v, w)
    c_sq = c ** 2

    coef_u = -c_sq * dot_uw * sq_v + c * dot_vw + 2 * c_sq * dot_uv * dot_vw
    coef_v = -c_sq * dot_vw * sq_u - c * dot_uw
    denom = (1 + 2 * c * dot_uv + c_sq * sq_u * sq_v).clamp_min(MIN_NORM)

    correction = 2 * (coef_u * u + coef_v * v) / denom
    return w + correction


def ptransp(x, y, u, c=1.0):
    """Transport the vector `u` from point `x` to point `y`.

    Computed as gyration(y, -x, u) * lambda_x(x) / lambda_x(y).

    x: source point
    y: target point
    u: vector attached to `x`
    c: curvature
    return: the transported vector
    """
    rotated = gyration(y, -x, u, c)
    scale_src = lambda_x(x, c)
    scale_dst = lambda_x(y, c)
    return rotated * scale_src / scale_dst


def expmap(u, p, c=1.0):
    """Exponential map of the tangent vector `u` at base point `p`.

    Computes p (+) tanh(sqrt(c) / 2 * lambda_x(p) * ||u||) * u / (sqrt(c) * ||u||),
    where (+) is `mobius_add`. This mirrors `expmap0`, with the extra conformal
    factor of the base point.

    u: tangent vectors, coordinates along the last dimension
    p: base points
    c: curvature
    return: result of the Mobius addition (no projection is applied here)
    """
    sqrt_c = c ** 0.5
    # Clamp the norm so the division below is safe for zero vectors.
    u_norm = u.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM)
    # tanh must act on the scalar radius only; the direction u / ||u|| is
    # applied afterwards. Applying tanh to the whole vector (as before)
    # squashes each coordinate independently and yields a wrong point.
    radius = tanh(sqrt_c / 2 * lambda_x(p, c) * u_norm)
    second_term = radius * u / (sqrt_c * u_norm)
    return mobius_add(p, second_term, c)


def logmap0(y, c=1.0):
    """Logarithmic map at the origin.

    y: points, coordinates along the last dimension
    c: curvature
    return: y / ||y|| / sqrt(c) * arctanh(sqrt(c) * ||y||)
    """
    root_c = c ** 0.5
    # Clamp the norm so the division below is safe for the origin.
    norm_y = y.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM)
    direction_over_root = y / norm_y / root_c
    return direction_over_root * arctanh(root_c * norm_y)


def project(x, c=1.0):
    """Rescale points whose norm exceeds (1 - eps) / sqrt(c) back onto that radius.

    `eps` is looked up in BALL_EPS by `x.dtype`, so a dtype missing from that
    table raises KeyError. Points already within the radius are returned
    unchanged.

    x: points, coordinates along the last dimension
    c: curvature
    return: tensor shaped like `x`
    """
    radius = x.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM)
    margin = BALL_EPS[x.dtype]
    boundary = (1 - margin) / (c ** 0.5)
    clipped = x / radius * boundary
    outside = radius > boundary
    return torch.where(outside, clipped, x)


def mobius_add(x, y, c=1.0):
    """Mobius addition of `x` and `y`.

    x, y: points, coordinates along the last dimension
    c: curvature
    return: ((1 + 2c<x,y> + c|y|^2) x + (1 - c|x|^2) y) / (1 + 2c<x,y> + c^2 |x|^2 |y|^2),
        with the denominator floored at MIN_NORM
    """
    def _dot(a, b):
        # Inner product along the last dimension, kept for broadcasting.
        return torch.sum(a * b, dim=-1, keepdim=True)

    sq_x = _dot(x, x)
    sq_y = _dot(y, y)
    dot_xy = _dot(x, y)

    coef_x = 1 + 2 * c * dot_xy + c * sq_y
    coef_y = 1 - c * sq_x
    denominator = 1 + 2 * c * dot_xy + c ** 2 * sq_x * sq_y
    return (coef_x * x + coef_y * y) / denominator.clamp_min(MIN_NORM)


def mobius_mul(x, t, c=1.0):
    """Mobius scalar multiplication of `x` by `t`.

    x: points, coordinates along the last dimension
    t: scalar factor
    c: curvature
    return: tanh(t * arctanh(sqrt(c) * ||x||)) * x / (||x|| * sqrt(c))
    """
    root_c = c ** 0.5
    # Clamp the norm so the division below is safe for the origin.
    norm_x = x.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM)
    new_radius = tanh(t * arctanh(root_c * norm_x))
    return new_radius * x / (norm_x * root_c)


def get_midpoint(x, y, c=1.0):
    """Computes the hyperbolic midpoint between x and y.

    Evaluated as x (+) 0.5 (*) ((-x) (+) y), using `mobius_add` for (+) and
    `mobius_mul` for (*).
    """
    half_step = mobius_mul(mobius_add(-x, y, c), 0.5, c)
    return mobius_add(x, half_step, c)


def get_midpoint_o(x, c=1.0):
    """
    Computes the hyperbolic midpoint between x and the origin,
    i.e. the Mobius scalar multiple of x by 0.5.
    """
    halfway = mobius_mul(x, 0.5, c)
    return halfway


# ################# HYP DISTANCES ########################


def hyp_distance(x, y, c=1.0):
    """
    Computes the hyperbolic distance between matching rows of x and y.

    x: hyperbolic queries (B x d)
    y: hyperbolic candidates (B x d)
    c: hyperbolic curvature (1)
    return: B x 1 matrix with hyperbolic distances,
        2 / sqrt(c) * arctanh(sqrt(c) * ||(-x) (+) y||)
    """
    sqrt_c = c ** 0.5
    x2 = torch.sum(x * x, dim=-1, keepdim=True)
    y2 = torch.sum(y * y, dim=-1, keepdim=True)
    xy = torch.sum(x * y, dim=-1, keepdim=True)
    c1 = 1 - 2 * c * xy + c * y2
    c2 = 1 - c * x2
    # Squared norm of the numerator of (-x) (+) y. It is mathematically
    # non-negative, but rounding can push it slightly below zero (e.g. when
    # x == y), which would make sqrt return NaN. Clamp it at zero, as
    # pairwise_hyp_distance does.
    sqnum = (c1 ** 2) * x2 + (c2 ** 2) * y2 - (2 * c1 * c2) * xy
    num = torch.sqrt(sqnum.clamp_min(0))
    denom = 1 - 2 * c * xy + c ** 2 * x2 * y2
    pairwise_norm = num / denom.clamp_min(MIN_NORM)
    dist = arctanh(sqrt_c * pairwise_norm)
    return 2 * dist / sqrt_c


def pairwise_hyp_distance(x, c=1.0):
    """Computes hyperbolic distances between every pair of rows of `x`.

    x: 2-D tensor of points, one per row (rows and columns are swapped with
        `transpose(0, 1)` to form the pairwise grid)
    c: curvature
    return: square matrix whose (i, j) entry is the distance between rows i and j
    """
    root_c = c ** 0.5
    row_sq = torch.sum(x * x, dim=-1, keepdim=True)
    # Same squared norms laid out along the columns, for the second point.
    col_sq = torch.sum(x * x, dim=-1, keepdim=True).transpose(0, 1)
    gram = x @ x.transpose(0, 1)

    coef_first = 1 - 2 * c * gram + c * col_sq
    coef_second = 1 - c * row_sq
    radicand = (coef_first ** 2) * row_sq + (coef_second ** 2) * col_sq - (2 * coef_first * coef_second) * gram
    # Clamp at zero so rounding noise never feeds a negative value to sqrt.
    numerator = torch.sqrt(radicand.clamp_min(0))
    denominator = 1 - 2 * c * gram + c ** 2 * row_sq * col_sq

    diff_norm = numerator / denominator.clamp_min(MIN_NORM)
    half_dist = arctanh(root_c * diff_norm)
    return 2 * half_dist / root_c


def dist_o(x, c=1.0):
    """
    Computes hyperbolic distance between x and the origin:
    2 / sqrt(c) * arctanh(sqrt(c) * ||x||), with ||x|| floored at MIN_NORM.
    """
    root_c = c ** 0.5
    radius = x.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM)
    half_dist = arctanh(root_c * radius)
    return 2 * half_dist / root_c

def hyp_dist_o(x, c=1.0):
    """
    Computes hyperbolic distance between x and the origin.

    x: points, coordinates along the last dimension
    c: curvature; previously accepted but ignored, it is now honoured the
        same way as in `dist_o` (results for the default c=1.0 are unchanged)
    return: 2 / sqrt(c) * arctanh(sqrt(c) * ||x||), last dimension kept at size 1
    """
    sqrt_c = c ** 0.5
    # Unlike dist_o, the norm is intentionally left unclamped, as before.
    x_norm = x.norm(dim=-1, p=2, keepdim=True)
    return 2 * arctanh(sqrt_c * x_norm) / sqrt_c


# ################# HYP LCA ########################


def hyp_lca(x, y, return_coord=True):
    """
    Computes hyperbolic LCA between two points.

    The formulas carry no curvature term, i.e. they are written for the unit
    ball.

    x, y: points, coordinates along the last dimension
    return_coord: when True, return the coordinates of the LCA; otherwise
        return 2 * arctanh(norm of the LCA)
    return: tensor shaped like `x` (coordinates) or with the last dimension
        reduced to size 1
    """
    eps = 1e-10
    x_norm = x.norm(dim=-1, p=2, keepdim=True).clamp_min(eps)
    y_norm = y.norm(dim=-1, p=2, keepdim=True).clamp_min(eps)
    # Angle between x and y; clamp guards acos against rounding outside [-1, 1].
    cos_xy = torch.sum(x * y, dim=-1, keepdim=True) / (x_norm * y_norm)
    theta = torch.acos(cos_xy.clamp(min=-1.0, max=1.0))
    quotient = (x_norm * (y_norm ** 2 + 1)) / (y_norm * (x_norm ** 2 + 1)).clamp_min(eps)
    sin_theta = torch.sin(theta)
    # Keep sin(theta) at least eps away from zero while preserving its sign.
    # An exact zero (e.g. x == y, where theta == 0) is treated as non-negative;
    # the previous sign()-based blend mapped it back to 0 and produced NaN.
    sin_theta = torch.where(sin_theta >= 0, sin_theta.clamp_min(eps), sin_theta.clamp_max(-eps))
    # alpha: angle between x and the LCA direction.
    alpha = torch.atan((quotient - torch.cos(theta)) / sin_theta)
    quotient = (x_norm ** 2 + 1).pow(2) / (2 * x_norm * torch.cos(alpha)).pow(2).clamp_min(eps)
    R = torch.sqrt((quotient - 1).clamp_min(eps))
    p_norm = torch.sqrt((R ** 2 + 1).clamp_min(eps)) - R
    if not return_coord:
        return 2 * arctanh(p_norm.clamp_min(eps))
    # Express the LCA in the (x / |x|, y / |y|) basis.
    b = p_norm * torch.abs(torch.sin(alpha) / torch.sin(theta).clamp_min(eps))
    a = p_norm * torch.cos(alpha) - b * torch.cos(theta)
    return a * (x / x_norm) + b * (y / y_norm)

