import jax
from jax import vmap
import jax.numpy as np
import numpy as onp
from jax import random

import utils
from utils import mean_f, vec_f


def _bootstrap_resample(data, key):
    """Draw a bootstrap resample of `data` along its first axis.

    Args:
        data: array whose leading axis indexes the observations.
        key: JAX PRNG key used to draw the row indices.

    Returns:
        An array with the same leading length as `data`, whose rows are
        picked from `data` uniformly at random with replacement.
    """
    size = data.shape[0]
    # One index in [0, size) per output row; repeats are allowed.
    draws = random.randint(key, shape=(size,), minval=0, maxval=size)
    return np.take(data, draws, axis=0)


def _jackknife_resample(data, i):
    """Return `data` with exactly one row removed along the first axis.

    The rows are rotated forward by `i` positions and the last row of the
    rotated array is dropped, so for `n = data.shape[0]` the omitted row is
    the one originally at index `(n - 1 - i) % n`. The remaining rows keep
    the rotated order, not the original order.

    Args:
        data: array whose leading axis indexes the observations.
        i: rotation amount selecting which row is left out.

    Returns:
        An array with leading length `data.shape[0] - 1`.
    """
    rotated = np.roll(data, shift=i, axis=0)
    return rotated[:-1]


def _percentile_of_score(a, score):
    """Return the fraction of `a` lying below `score`, counting ties as half.

    The result is the average of the "weak" fraction (elements <= score) and
    the "strict" fraction (elements < score), so it lies in [0, 1].

    Args:
        a: array of values.
        score: value to locate within `a`.

    Returns:
        The mid-rank fraction as a scalar.
    """
    weak_count = (a <= score).sum()
    strict_count = (a < score).sum()
    return (weak_count + strict_count) / a.size / 2


def _bca_interval(theta_hat, theta_hat_b, theta_hat_i, alpha):
    """Compute a bias-corrected and accelerated (BCa) interval.

    Args:
        theta_hat: statistic evaluated on the original data.
        theta_hat_b: bootstrap replicates of the statistic.
        theta_hat_i: jackknife (leave-one-out) replicates of the statistic.
        alpha: lower tail probability; the upper tail uses the mirrored
            normal quantile.

    Returns:
        A `(ci_l, ci_u)` tuple: the percentiles of `theta_hat_b` taken at
        the bias- and acceleration-adjusted levels.
    """
    from jax.scipy.special import ndtri, ndtr

    # Bias correction: normal quantile of the mid-rank position of
    # theta_hat among the bootstrap replicates.
    z0_hat = ndtri(_percentile_of_score(theta_hat_b, theta_hat))

    # Acceleration: skewness-like ratio of the jackknife deviations.
    deviations = theta_hat_i.mean() - theta_hat_i
    a_hat = (deviations ** 3).sum() / (6 * (deviations ** 2).sum() ** (3 / 2))

    def _adjusted_level(z):
        # Map a nominal normal quantile to the corrected tail probability.
        shifted = z0_hat + z
        return ndtr(z0_hat + shifted / (1 - a_hat * shifted))

    z_alpha = ndtri(alpha)
    lower_level = _adjusted_level(z_alpha)
    upper_level = _adjusted_level(-z_alpha)

    return (
        np.percentile(theta_hat_b, lower_level * 100),
        np.percentile(theta_hat_b, upper_level * 100),
    )


def _percentile_interval(theta_hat_b, alpha):
    """Compute the percentile bootstrap interval.

    Args:
        theta_hat_b: bootstrap replicates of the statistic.
        alpha: tail probability cut off on each side.

    Returns:
        A `(ci_l, ci_u)` tuple holding the `alpha` and `1 - alpha` quantiles
        of `theta_hat_b`.
    """
    lower_q = alpha * 100
    upper_q = (1 - alpha) * 100
    return np.percentile(theta_hat_b, lower_q), np.percentile(theta_hat_b, upper_q)


def _basic_interval(theta_hat, theta_hat_b, alpha):
    """Compute the basic (reverse percentile) bootstrap interval.

    The `alpha` and `1 - alpha` quantiles of the bootstrap replicates are
    reflected around `theta_hat`, which swaps their roles as bounds.

    Args:
        theta_hat: statistic evaluated on the original data.
        theta_hat_b: bootstrap replicates of the statistic.
        alpha: tail probability cut off on each side.

    Returns:
        A `(ci_l, ci_u)` tuple equal to
        `(2 * theta_hat - q_hi, 2 * theta_hat - q_lo)`.
    """
    q_lo = np.percentile(theta_hat_b, alpha * 100)
    q_hi = np.percentile(theta_hat_b, (1 - alpha) * 100)
    return 2 * theta_hat - q_hi, 2 * theta_hat - q_lo


def _vmap_over_leading_axes(interval_fn, in_axes, n_axes):
    """Wrap `interval_fn` in `vmap` once for each of `n_axes` leading axes."""
    for _ in range(n_axes):
        interval_fn = vmap(interval_fn, in_axes)
    return interval_fn


def bootstrap(
    data,
    statistic,
    keys,
    confidence_level=0.95,
    method="BCa",
    statistic_kw=None,
    loop_kwargs=None,
):
    """Compute a two-sided bootstrap confidence interval for a statistic.

    Args:
        data: array whose leading axis indexes the observations; resampling
            happens along that axis.
        statistic: callable invoked as `statistic(sample, **statistic_kw)`.
        keys: PRNG keys handed to `vec_f`; each bootstrap replicate is
            computed from one key (presumably `vec_f` maps over the leading
            axis of `keys` -- confirm against `utils.vec_f`).
        confidence_level: coverage of the interval; each tail gets
            `(1 - confidence_level) / 2`.
        method: one of `"BCa"`, `"percentile"` or `"basic"`.
        statistic_kw: optional extra keyword arguments for `statistic`.
        loop_kwargs: optional extra keyword arguments forwarded to `vec_f`.

    Returns:
        A `(ci_l, ci_u)` tuple. The interval function is vmapped over every
        axis of `statistic(data)`, so each bound has the statistic's shape.

    Raises:
        ValueError: if `method` is not one of the supported names. The check
            runs before any resampling so a typo fails immediately.
    """
    if method not in ("BCa", "percentile", "basic"):
        raise ValueError("Unknown bootstrap method.")

    # None defaults avoid sharing one mutable dict across calls.
    statistic_kw = {} if statistic_kw is None else statistic_kw
    loop_kwargs = {} if loop_kwargs is None else loop_kwargs

    alpha = (1 - confidence_level) / 2
    theta_hat = statistic(data, **statistic_kw)
    theta_hat_b = vec_f(
        lambda k: statistic(_bootstrap_resample(data, k), **statistic_kw),
        keys,
        **loop_kwargs
    )

    # np.ndim also accepts plain Python scalars, unlike `.shape`.
    n_axes = np.ndim(theta_hat)
    # Move the replicate axis last so that, after vmapping over the
    # statistic's own axes, each interval function sees a 1-D replicate vector.
    replicates = np.moveaxis(theta_hat_b, 0, -1)

    if method == "percentile":
        interval_fn = _vmap_over_leading_axes(
            _percentile_interval, (0, None), n_axes
        )
        return interval_fn(replicates, alpha)

    if method == "basic":
        interval_fn = _vmap_over_leading_axes(
            _basic_interval, (0, 0, None), n_axes
        )
        return interval_fn(theta_hat, replicates, alpha)

    # method == "BCa": additionally needs the leave-one-out replicates.
    theta_hat_i = vec_f(
        lambda i: statistic(_jackknife_resample(data, i), **statistic_kw),
        np.arange(0, data.shape[0]),
        **loop_kwargs
    )
    interval_fn = _vmap_over_leading_axes(
        _bca_interval, (0, 0, 0, None), n_axes
    )
    return interval_fn(
        theta_hat,
        replicates,
        np.moveaxis(theta_hat_i, 0, -1),
        alpha,
    )
