import numpy as np
from kl_nearest import skl_estimator
from scipy.linalg import sqrtm

def compute_kl_log(samples_from_1, seed=2609, against_posterior=None):
    """
    Estimate a KL divergence between ``samples_from_1`` and a reference sample set.

    If ``against_posterior`` is None, the reference is ``m`` draws from a standard
    Gaussian N(0, I) of matching dimension ``d``, and the estimator is called as
    ``skl_estimator(samples_from_1, gaussian_draws, ...)``.

    Otherwise ``m`` rows are drawn without replacement from ``against_posterior``.
    The estimator is then called as ``skl_estimator(posterior_draws,
    samples_from_1, ...)``; the argument order is swapped relative to the
    Gaussian branch.

    Both sets are jointly centred (pooled mean) and whitened before estimation.
    Whitening multiplies by the inverse of the matrix square root of the pooled
    sample covariance. If the estimator raises ``ValueError``, the neighbour
    count ``k`` (starting at 4) is increased by one and the estimate is retried.

    Args:
        samples_from_1: 2-D array of shape (m, d).
        seed: seed for ``np.random.default_rng``. It controls the Gaussian draws
            or the row subsampling of ``against_posterior``.
        against_posterior: optional array with at least ``m`` rows and ``d``
            columns.

    Returns:
        Whatever ``skl_estimator(..., error=True)`` returns.

    Raises:
        ValueError: if ``against_posterior`` has fewer than ``m`` rows (raised by
            ``rng.choice``).
        ValueError: if the estimator still raises ``ValueError`` once ``k`` has
            reached ``m``. The original retried forever in this case.
        numpy.linalg.LinAlgError: if the pooled covariance is singular. The
            original caught this (it subclasses ValueError) and looped forever.
    """
    rng = np.random.default_rng(seed)
    m, d = samples_from_1.shape
    if against_posterior is None:
        samples_from_2 = rng.normal(loc=np.zeros(d), scale=np.ones(d), size=(m, d))
    else:
        # NOTE(review): posterior draws become the *first* estimator argument
        # here, the reverse of the Gaussian branch. Confirm this KL direction is
        # intended.
        samples_from_1, samples_from_2 = (
            rng.choice(against_posterior, size=m, replace=False),
            samples_from_1,
        )

    # Whiten exactly once. Doing this inside the retry loop re-transformed
    # already-whitened data on every retry.
    both = np.vstack([samples_from_1, samples_from_2])
    mu = np.mean(both, 0)
    # atleast_2d: np.cov returns a 0-d array when d == 1, which sqrtm rejects.
    # NOTE(review): sqrtm can return a complex matrix for ill-conditioned input.
    inv_sqrt = np.linalg.inv(sqrtm(np.atleast_2d(np.cov(both.T))))
    samples_from_1 = (samples_from_1 - mu) @ inv_sqrt
    samples_from_2 = (samples_from_2 - mu) @ inv_sqrt

    k = 4
    while True:
        try:
            return skl_estimator(samples_from_1, samples_from_2, k=k, error=True)
        except ValueError:
            # Each set has m rows, so more than m neighbours can never succeed.
            # Give up instead of spinning forever.
            if k >= m:
                raise
            k += 1


def compute_kl_gibbs(posteriors, method, mode='nearestc', model_prior_params=None,**kwargs):
    """
    Posteriors is a dictionary {'prior': (samples of theta, sigma_squared),
                                method: (samples of theta, sigma_squared)}
    """
    def concat_theta_and_sigma(val):
        return np.hstack((val[0], val[1].reshape(-1,1)))

    if mode == 'analytic' and method == 'non-private':
        return posteriors['non-private'][2]

    elif mode == 'mle-fit':
        from mle_fit import nig_fit, n_fit
        samples = posteriors[method][1], posteriors[method][0]
        kl, _  = nig_fit(samples, model_prior_params)
        # posteriors["mle-fit-{}".format(method)]
        return kl

    elif mode == 'nearestc':
        samples_prior = concat_theta_and_sigma(posteriors['prior'])
        samples_post =  concat_theta_and_sigma(posteriors[method])
        x = np.vstack([samples_prior, samples_post])

        sample_cov = np.cov(x.T)
        sqrt_cov = sqrtm(sample_cov)
        inv_sqrt = np.linalg.inv(sqrt_cov)

        mu, std =  np.mean(x,0), np.std(x, 0)

        samples_post = (samples_post - mu) @ inv_sqrt
        samples_prior = (samples_prior - mu) @ inv_sqrt

        kl = skl_estimator(samples_post, samples_prior,k=4)
        k = 4
        while not np.isfinite(kl):
            k += 2
            kl = skl_estimator(samples_post, samples_prior,k=k)

        if kl < 0: kl = 0
        return kl

    elif mode == 'nearest':
        samples_prior = concat_theta_and_sigma(posteriors['prior'])
        samples_post =  concat_theta_and_sigma(posteriors[method])
        x = np.vstack([samples_prior, samples_post])
        mu, std =  np.mean(x,0), np.std(x, 0)
        # print("before std", skl_estimator(samples_post, samples_prior,k=4))
        samples_post = (samples_post - mu) /std
        samples_prior = (samples_prior - mu) /std
        kl = skl_estimator(samples_post, samples_prior,k=4)
        if kl < 0: kl = 0
        return kl

    elif mode == 'nearestu':
        samples_prior = concat_theta_and_sigma(posteriors['prior'])
        samples_post =  concat_theta_and_sigma(posteriors[method])
        kl = skl_estimator(samples_post, samples_prior,k=4)
        if kl < 0: kl = 0
        return kl
