import numpy as np
import scipy
import scipy.stats
from scipy.special import gammaln, digamma

from util import isPD, nearestPD, symmetrize

def model_conjugate_update(model_prior_params, S, N, project=False, size=1, return_kl=False, k=1):
    """Apply the conjugate update to the prior and sample from the posterior.

    The prior/posterior pair is parameterised as ``(mu, lambda, a, b)`` where
    ``lambda`` is a *precision* matrix, not a covariance (see page 3 of the
    paper).

    Args:
        model_prior_params: Sequence ``(mu_0, lambda_0, a_0, b_0)``. The
            ``[0, 0]`` indexing in the ``b_n`` update requires ``mu_0`` to be
            a 2-D column vector of shape ``(d, 1)`` and ``lambda_0`` to be
            ``(d, d)``.
        S: Mapping of sufficient statistics with keys ``'XX'``, ``'Xy'`` and
            ``'yy'``.
        N: Count entering the shape update ``a_n = a_0 + 0.5 * k * N``
            (presumably the number of observations; confirm with caller).
        project: If true, floor ``b_n`` at ``0.1`` before sampling. The KL
            term, when requested, is computed from the un-floored ``b_n``.
        size: Number of ``sigma_squared`` draws, forwarded to
            ``scipy.stats.invgamma.rvs``; one ``theta`` is drawn per
            ``sigma_squared`` sample.
        return_kl: If true, also compute, print and return the KL divergence
            of the posterior from the prior.
        k: Multiplier applied to every sufficient statistic and to ``N``.

    Returns:
        ``(theta, sigma_squared)``, or ``(theta, sigma_squared, kl)`` when
        ``return_kl`` is true. ``theta`` is an array with one multivariate
        normal draw per entry of ``sigma_squared``.
    """
    mu_0, lambda_0, a_0, b_0 = model_prior_params

    # Eq. (1) in the paper
    lambda_n = k * S['XX'] + lambda_0
    # NOTE(review): the inverse is taken from the un-projected lambda_n, so
    # after a nearestPD repair below the two matrices are no longer exact
    # inverses of each other. Kept as-is to preserve numerical behaviour.
    inv_lambda_n = np.linalg.inv(lambda_n)

    if not isPD(lambda_n):
        lambda_n = nearestPD(lambda_n)

    if not isPD(inv_lambda_n):
        inv_lambda_n = nearestPD(inv_lambda_n)

    mu_n = inv_lambda_n.dot(k * S['Xy'] + lambda_0.dot(mu_0))
    a_n = a_0 + .5 * k * N
    b_n = b_0 + .5 * (
        k * S['yy']
        + mu_0.T.dot(lambda_0).dot(mu_0)
        - mu_n.T.dot(lambda_n).dot(mu_n)
    )[0, 0]

    if return_kl:
        # Normal-gamma KL from https://statproofbook.github.io/P/ng-kl.html
        # with P = posterior (n) and Q = prior (0). The KL between two
        # inverse-gamma distributions equals the KL between the matching
        # gamma distributions:
        # https://en.wikipedia.org/wiki/Inverse-gamma_distribution
        gamma_kl = (
            a_0 * np.log(b_n / b_0)
            - (gammaln(a_n) - gammaln(a_0))
            + (a_n - a_0) * digamma(a_n)
            - (b_n - b_0) * a_n / b_n
        )

        diff_mu = mu_0 - mu_n
        squared_diff_mu = np.squeeze(diff_mu.T @ lambda_0 @ diff_mu)

        _, logdet_0 = np.linalg.slogdet(lambda_0)
        _, logdet_n = np.linalg.slogdet(lambda_n)

        # The quadratic term is weighted by E[1 / sigma^2] = a_n / b_n (the
        # expected precision), not by E[sigma^2] = b_n / (a_n - 1).
        expect_normal_kl = (
            a_n / b_n * squared_diff_mu
            + np.trace(lambda_0 @ inv_lambda_n)
            - logdet_0
            + logdet_n
            - len(mu_0)
        )
        expect_normal_kl = expect_normal_kl / 2
        kl = gamma_kl + expect_normal_kl
        print("True KL", round(kl, 6), "=InvGamma KL", round(gamma_kl, 6), " + expected normal KL", round(expect_normal_kl,6))

    if project:
        # Keep the inverse-gamma scale strictly positive before sampling.
        b_n = max(b_n, .1)

    sigma_squared = scipy.stats.invgamma.rvs(a=a_n, scale=b_n, size=size)

    # The posterior mean is the same for every draw; flatten it once.
    posterior_mean = mu_n.flatten()
    theta = np.array([
        scipy.stats.multivariate_normal.rvs(posterior_mean, symmetrize(ss * inv_lambda_n))
        for ss in sigma_squared
    ])

    if return_kl:
        return theta, sigma_squared, kl
    return theta, sigma_squared
