import os
import math
import jax
import matplotlib.pyplot as plt
import seaborn as sns
import hydra
import numpy as np
import jax.numpy as jnp
from jax import random

from dist import GaussMixJax, get_posterior
from util import diffusion_sampler
from util import generate_inverse_problem_gm, sliced_wasserstein
from numpyro.distributions import MultivariateNormal, Categorical, MixtureSameFamily


def _denoising_boost(A, y, var_y, target, centers, dim, eps):
    """Build the analytic "boost" used to start the reverse SDE at an intermediate time.

    With the thin SVD ``A = U @ diag(D) @ Vt`` and
    ``T_denoise = log(1 + var_y / max(D)**2) / 2``, the boost time is
    ``t = T_denoise - eps``. The function returns:

    * ``prior_middle``: a Gaussian mixture with the target's weights, means
      ``centers * exp(-t)`` and isotropic covariances
      ``(var_scale**2 * exp(-2t) + 1 - exp(-2t)) * I``.
      This presumably matches the Ornstein-Uhlenbeck-noised target at time t,
      assuming each target component is N(center, var_scale**2 * I).
      TODO confirm against GaussMixJax.
    * ``A_singular = diag(sqrt(exp(2t) / (1 + var_y / D**2 - exp(2t)))) @ Vt``
    * ``obs_singular = diag(sqrt(exp(2t) / ((1 - exp(2t)) * D**2 + var_y)))
      @ (exp(-t) * U.T @ y)``

    At ``t = T_denoise`` we have ``exp(2t) = 1 + var_y / max(D)**2``, so both
    denominators vanish along the top singular direction. Keeping ``eps > 0``
    keeps them strictly positive.

    Args:
        A: forward operator. Presumably of shape (dim_y, dim) with
            dim_y <= dim, so that ``U.T @ y`` has length dim_y. TODO confirm.
        y: observation vector.
        var_y: observation-noise variance. The caller passes ``Sigma_y[0][0]``.
        target: mixture object exposing ``weights``, ``var_scale`` and ``n_center``.
        centers: array of component means, one row per component.
        dim: dimension of x.
        eps: gap between the boost time and ``T_denoise``.

    Returns:
        ``(T_denoise, prior_middle, A_singular, obs_singular)``.

    Raises:
        ValueError: if ``T_denoise`` is not finite or does not exceed ``eps``,
            since the boost time would then be non-positive.
    """
    U, D, Vt = jnp.linalg.svd(A, full_matrices=False)
    lambda_max = D.max() ** 2
    T_denoise = np.log(1 + var_y / lambda_max) / 2
    if not np.isfinite(T_denoise) or T_denoise <= eps:
        raise ValueError(
            f"T_denoise={T_denoise} must be finite and larger than eps={eps}; "
            "the boost time T_denoise - eps would otherwise be non-positive."
        )
    t_mid = T_denoise - eps

    # Mixture at time t_mid: means shrunk by exp(-t), isotropic covariances inflated.
    var_mid = target.var_scale ** 2 * jnp.exp(-t_mid * 2) + 1 - jnp.exp(-t_mid * 2)
    covs = jnp.repeat(jnp.eye(dim)[None] * var_mid, axis=0, repeats=target.n_center)
    prior_middle = MixtureSameFamily(
        mixing_distribution=Categorical(target.weights),
        component_distribution=MultivariateNormal(centers * jnp.exp(-t_mid), covariance_matrix=covs),
    )

    # Per-singular-direction scalings. Both denominators are > 0 because t_mid < T_denoise.
    growth = jnp.exp(2 * t_mid)
    A_scale = (growth / (1 + var_y / D ** 2 - growth)) ** (1 / 2)
    A_singular = jnp.diag(A_scale) @ Vt
    aux_singular = (growth / ((1 - growth) * D ** 2 + var_y)) ** (1 / 2)
    obs_singular = jnp.diag(aux_singular) @ (jnp.exp(-t_mid) * (U.T @ y))
    return T_denoise, prior_middle, A_singular, obs_singular


@hydra.main(version_base=None, config_path="configs/", config_name="gmm")
def main(cfg):
    """Run the analytic-boost + reverse-SDE sampler on a Gaussian-mixture inverse problem.

    Steps:
      1. Build a 25-component Gaussian-mixture target on a 5x5 grid (spacing 8).
      2. Draw a linear inverse problem with ``generate_inverse_problem_gm``.
      3. Sample particles from the analytic boost at time ``T_denoise - eps``.
      4. Run ``diffusion_sampler`` from there.
      5. Compare the result with samples from ``target.sample_posterior``
         (presumably the exact posterior; TODO confirm) using a sliced
         Wasserstein distance.
      6. Save both sample sets to
         ``<save_path>/dx{dim}dy{dim_y}[_{exp_suffix}]/rdkey{random_key}``
         via ``jnp.savez``.

    Every random draw uses its own fresh subkey. JAX keys must not be used
    for sampling and then split again.

    Args:
        cfg: Hydra config. Reads ``random_key``, ``n_groups``, ``n_samples``,
            ``plot``, ``exp_suffix``, ``save_path`` and ``example.{n_particles,
            n_steps, T, dim, dim_y, kappa_range, snr_range}``. ``cfg`` is also
            forwarded to ``diffusion_sampler``.
    """
    example_cfg = cfg.example
    n_groups = cfg.n_groups  # each group is run independently
    n_particles = example_cfg.n_particles
    n_steps = example_cfg.n_steps
    T = example_cfg.T
    dim = example_cfg.dim
    dim_y = example_cfg.dim_y
    n_samples = cfg.n_samples
    eps = 1e-2  # gap between the boost time and T_denoise (see _denoising_boost)
    plot_dim = 0  # coordinate shown in the diagnostic KDE plot

    key = jax.random.PRNGKey(cfg.random_key)

    # ===== target: 5x5 grid of mixture components =====
    scale = 1
    centers = jnp.array([
        [-8. * scale * i, -8. * scale * j] * (dim // 2)
        for i in range(-2, 3)
        for j in range(-2, 3)
    ])
    key, subkey = random.split(key)
    weights = random.uniform(subkey, (centers.shape[0],)) ** 2
    weights = weights / weights.sum()
    target = GaussMixJax(centers, weights, scale)

    # ===== generate problem =====
    key, subkey = random.split(key)
    y, A, Sigma_y, _x_origin = generate_inverse_problem_gm(
        subkey, dim, dim_y, target, scale, example_cfg.kappa_range, example_cfg.snr_range
    )
    var_y = Sigma_y[0][0]  # only the (0, 0) entry is used: noise treated as isotropic

    # ===== analytic boost at time T_denoise - eps =====
    T_denoise, prior_middle, A_singular, obs_singular = _denoising_boost(
        A, y, var_y, target, centers, dim, eps
    )
    n_step_denoise = math.ceil(T_denoise / T * n_steps)
    dts_denoise = jnp.ones(n_step_denoise) * (T_denoise - eps) / n_step_denoise
    print(f"T denoise: {T_denoise}, n_step_denoise: {n_step_denoise}")

    # ===== reverse SDE from analytical boost =====
    key, subkey = random.split(key)
    singular_init = get_posterior(obs_singular, prior_middle, A_singular, jnp.eye(dim_y))
    x0s_boost = singular_init.sample(subkey, (n_groups, n_particles))
    key, subkey = random.split(key)
    group_keys = random.split(subkey, n_groups)
    # One key per (group, particle): shape (n_groups, n_particles, key_size).
    keys = jax.vmap(lambda k: random.split(k, n_particles))(group_keys)
    samples_boost_full = diffusion_sampler(keys, x0s_boost, dts_denoise, cfg, target.score_fn)

    # ===== reference samples and evaluation =====
    key, subkey = random.split(key)
    samples_true_full = target.sample_posterior(subkey, (n_samples,))
    if cfg.plot:
        samples_boost = samples_boost_full[..., plot_dim].reshape([-1])
        samples_true = samples_true_full[:, plot_dim]
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        cmap = plt.get_cmap("tab10")
        sns.kdeplot(samples_boost, color='k', ax=ax, label='Samples')
        sns.kdeplot(samples_true, color=cmap(0), ax=ax, label='True Posterior')
        sns.kdeplot(x0s_boost[..., plot_dim].reshape([-1]), color=cmap(1), ax=ax, label='Boosted Posterior')
        ax.legend()
        plt.show()
        plt.close(fig)
        print(samples_boost.mean(), samples_boost.std())
        print(samples_true.mean(), samples_true.std())
    key, wass_subkey = random.split(key)
    swd = sliced_wasserstein(
        wass_subkey, samples_true_full, samples_boost_full.reshape((-1, dim)), n_slices=100
    )
    print("Sliced wasserstein distance: {}".format(swd))

    # ===== save =====
    run_name = "dx{}dy{}".format(dim, dim_y)
    if cfg.exp_suffix:
        run_name += "_{}".format(cfg.exp_suffix)
    save_path = os.path.join(cfg.save_path, run_name)
    os.makedirs(save_path, exist_ok=True)
    jnp.savez(
        os.path.join(save_path, 'rdkey{}'.format(cfg.random_key)),
        samples_true=samples_true_full,
        samples_boost=samples_boost_full,
    )



if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main composes the
    # config (config_path="configs/", config_name="gmm") and passes it in as cfg.
    main()
