import math
import tqdm
import jax
import blackjax
import numpyro
import jax.numpy as jnp
from jax import grad, lax, vmap, jit, pmap
from jax.tree_util import Partial as partial

def nuts_warmup(key, sample, logprob_fun):
    """Run blackjax NUTS window adaptation from a single starting point.

    Args:
        key: JAX PRNG key. It is split in two; only the first sub-key is
            consumed (by the adaptation run).
        sample: Starting position. It is wrapped as ``{"loc": sample}``
            before being handed to blackjax.
        logprob_fun: Log density that accepts a ``{"loc": x}`` dict.

    Returns:
        ``(position, step_size, inverse_mass_matrix)``, each taken from the
        last entry of the adaptation trace. The step size is ``exp`` of the
        final averaged log step size.

    NOTE(review): the indexing below assumes the tuple layout returned by the
    installed blackjax version's ``window_adaptation(...).run``. Its last
    element is the warmup trace, whose items 0 and 2 hold the chain states and
    the adaptation states. Re-check this when upgrading blackjax.
    """
    adapt_key, _ = jax.random.split(key, 2)
    adaptation = blackjax.window_adaptation(blackjax.nuts, logprob_fun)
    results = adaptation.run(adapt_key, {"loc": sample})

    warmup_trace = results[-1]
    chain_states = warmup_trace[0]
    adaptation_states = warmup_trace[2]

    final_position = chain_states.position['loc'][-1]
    final_step_size = jnp.exp(adaptation_states.ss_state.log_step_size_avg[-1])
    final_inv_mass = adaptation_states.imm_state.inverse_mass_matrix[-1]
    return final_position, final_step_size, final_inv_mass

def bkjx_loop(key, init_state, kernel, steps):
    """Apply ``kernel.step`` ``steps`` times, starting from ``kernel.init(init_state)``.

    Every transition receives its own sub-key from
    ``jax.random.split(key, steps)``, consumed in order. The auxiliary second
    output of ``kernel.step`` is discarded.

    Args:
        key: JAX PRNG key.
        init_state: Value passed to ``kernel.init`` to build the initial state.
        kernel: Object exposing ``init(x)`` and ``step(key, state) -> (state, info)``.
        steps: Number of transitions; must be a concrete Python int.

    Returns:
        The kernel state after the last transition.
    """
    def transition(carry, step_key):
        next_state, _info = kernel.step(step_key, carry)
        return next_state, None

    step_keys = jax.random.split(key, steps)
    final_state, _ = jax.lax.scan(transition, kernel.init(init_state), step_keys)
    return final_state

def nuts_once(key, state, step_size, inverse_mass_matrix, logprob_fun, num_steps):
    """Run ``num_steps`` NUTS transitions from ``state`` with fixed tuning.

    Args:
        key: JAX PRNG key driving the chain.
        state: Starting position. It is wrapped as ``{"loc": state}``.
        step_size: NUTS step size (not adapted here).
        inverse_mass_matrix: Inverse mass matrix (not adapted here).
        logprob_fun: Log density accepting a ``{"loc": x}`` dict.
        num_steps: Number of NUTS transitions to apply.

    Returns:
        The ``"loc"`` entry of the chain's final position.
    """
    kernel = blackjax.nuts(logprob_fun,
                           step_size=step_size,
                           inverse_mass_matrix=inverse_mass_matrix)
    final_state = bkjx_loop(key, {"loc": state}, kernel, num_steps)
    return final_state.position['loc']

def sample_nuts(D_X, key, mixt_dist_jax, n_chains, posterior_logprob_nuts, num_steps=1000):
    """Run warmed-up, resampled NUTS chains in parallel across all CPU devices.

    The sampler works in three stages:

    1. Window adaptation (``nuts_warmup``) runs for ``n_batches * batch_size``
       chains, each started from a draw of ``mixt_dist_jax``.
    2. The warmed-up positions, together with their adapted step sizes and
       inverse mass matrices, are resampled with replacement. The probabilities
       are ``softmax(posterior_logprob_nuts(position))``.
    3. ``num_steps`` NUTS transitions (``nuts_once``) run from every
       resampled state.

    Args:
        D_X: Unused; kept for call-site compatibility.
        key: JAX PRNG key.
        mixt_dist_jax: Distribution exposing ``sample(key, sample_shape=...)``
            (presumably a numpyro distribution). Used to draw the warmup
            starting points.
        n_chains: Requested number of chains. It is rounded up to a multiple of
            the number of CPU devices, so more chains than requested may be
            returned.
        posterior_logprob_nuts: Log density taking ``{"loc": x}`` and returning
            a scalar.
        num_steps: Number of NUTS transitions per chain after warmup.

    Returns:
        ``(nuts_samples, initial_positions)``: the final NUTS positions and the
        resampled starting positions. Both have leading dimension
        ``ceil(n_chains / n_cpu_devices) * n_cpu_devices``.
    """
    key_warmup, key_nuts, key_categorical = jax.random.split(key, 3)
    devices = jax.devices("cpu")
    batch_size = len(devices)
    n_batches = math.ceil(n_chains / batch_size)

    # Stage 1: warmup, one pmap call per batch of `batch_size` chains.
    nuts_warmup_fun = pmap(partial(nuts_warmup, logprob_fun=posterior_logprob_nuts),
                           devices=devices)
    nuts_warmups = []
    for k in jax.random.split(key_warmup, n_batches):
        # Draw starting points and adaptation keys from independent sub-keys.
        # Using `k` for both would correlate the start with the adaptation.
        key_chains, key_init = jax.random.split(k)
        nuts_warmups.append(nuts_warmup_fun(
            jax.random.split(key_chains, batch_size),
            mixt_dist_jax.sample(key_init, sample_shape=(batch_size,))))
    initial_positions = jnp.concatenate([x[0] for x in nuts_warmups])
    step_sizes = jnp.concatenate([x[1] for x in nuts_warmups])
    inverse_mass_matrixes = jnp.concatenate([x[2] for x in nuts_warmups])

    # Stage 2: resample warmed-up states by posterior density. Tuning
    # parameters follow their position.
    logits = vmap(posterior_logprob_nuts)({"loc": initial_positions})
    ancestors = numpyro.distributions.Categorical(logits=logits).sample(
        key_categorical, sample_shape=(logits.shape[0],))
    initial_positions = initial_positions[ancestors]
    step_sizes = step_sizes[ancestors]
    inverse_mass_matrixes = inverse_mass_matrixes[ancestors]

    # Stage 3: NUTS sampling, again one pmap call per device-sized batch.
    nuts_sampler = pmap(partial(nuts_once, logprob_fun=posterior_logprob_nuts, num_steps=num_steps),
                        devices=devices)
    batched_positions = initial_positions.reshape(-1, batch_size, initial_positions.shape[-1])
    batched_step_sizes = step_sizes.reshape(-1, batch_size)
    batched_inv_mass = inverse_mass_matrixes.reshape(
        -1, batch_size, *inverse_mass_matrixes.shape[1:])
    # Use the dedicated `key_nuts`. The parent `key` was already split above,
    # so splitting it again would reuse its randomness.
    nuts_samples = jnp.concatenate([
        nuts_sampler(jax.random.split(k, batch_size), init, step, inv_mass)
        for k, init, step, inv_mass in zip(jax.random.split(key_nuts, n_batches),
                                           batched_positions,
                                           batched_step_sizes,
                                           batched_inv_mass)
    ])
    return nuts_samples, initial_positions