import numpy as np
import scipy
import scipy.stats

from NIG import NIG_rvs, NIG_rvs_single_variance
from NIW import NIW_rvs, NIW_conjugate_update
from Gibbs_suff_stat_update import update_sufficient_statistics, calc_Cov_xx_xx
from util import isPD, nearestPD, project_suff_stats, symmetrize, calc_posterior_params, fast_sample_multivariate_normal, cho_inverse
from model_conjugate_update import model_conjugate_update

import time
from copy import deepcopy

def Gibbs(data_prior_params,
          model_prior_params,
          N,
          sigma_DP_noise,
          Z,
          num_burnin,
          num_iterations,
          gibbs_flavor,
          X_fourth_moment=None, k=1):
    """Gibbs sampler for the regression parameters given noisy sufficient statistics Z.

    Each sweep imputes sufficient statistics S, draws (theta, sigma^2) from the
    model posterior given k*S and k*N, and then refreshes the feature moments.

    Parameters
    ----------
    data_prior_params : prior parameters for p(x); unpacked into NIW_rvs at
        initialization and passed to update_moments.
    model_prior_params : prior parameters for (theta, sigma^2); unpacked into NIG_rvs.
    N : number of records summarized by Z.
    sigma_DP_noise : scalar noise standard deviation; its square is placed on
        the diagonal of the noise covariance (see initialize_values).
    Z : dict of noisy sufficient statistics; Z['X'].shape[0] fixes the length of theta.
    num_burnin : number of initial sweeps to discard.
    num_iterations : number of sweeps to keep.
    gibbs_flavor : 'gibbs-Isserlis' computes the moments from the data parameters.
        Any other flavor uses the caller-supplied X_fourth_moment.
    X_fourth_moment : 4-D fourth-moment tensor; required unless gibbs_flavor is
        'gibbs-Isserlis'.
    k : multiplier applied to the imputed statistics and to N before the model update.

    Returns
    -------
    theta : ndarray, shape (num_iterations, Z['X'].shape[0])
    sigma_squared : ndarray, shape (num_iterations,)

    Raises
    ------
    ValueError
        If X_fourth_moment is None while gibbs_flavor is not 'gibbs-Isserlis'.
    """
    # Validate before any sampling so a missing input fails fast with a clear message
    # instead of a TypeError from indexing None.
    if gibbs_flavor != 'gibbs-Isserlis' and X_fourth_moment is None:
        raise ValueError(
            f"X_fourth_moment must be provided for gibbs flavor '{gibbs_flavor}'")

    # theta, sigma^2, \mu_x, \tau^2, \omega^2
    model_params, data_params, noise_covariance = initialize_values(
        data_prior_params, model_prior_params, sigma_DP_noise, Z, gibbs_flavor)

    if gibbs_flavor == 'gibbs-Isserlis':
        X_second_moment, X_fourth_moment, Cov_xx_xx = calc_Isserlis_moments(data_params)
    else:
        # Assumes X_fourth_moment is the moment tensor of the augmented vector [x, 1]
        # (as built by calc_X_fourth_moment), so E[x x^T] sits in the [..., -1, -1] slots.
        X_second_moment = X_fourth_moment[:, :, -1, -1]

        Cov_xx_xx = calc_Cov_xx_xx(X_second_moment, X_fourth_moment)
        if not isPD(X_second_moment):
            X_second_moment = nearestPD(X_second_moment)

    theta = np.zeros((num_iterations, Z['X'].shape[0]))
    sigma_squared = np.zeros(num_iterations)
    for iteration in range(num_iterations + num_burnin):

        S = update_sufficient_statistics(X_second_moment, Cov_xx_xx, Z, model_prior_params,
                                         model_params, noise_covariance, N)

        # Scale the imputed statistics (and N) by k for the model-parameter update only.
        sumS = {key: k * value for key, value in S.items()}
        model_params = update_model_params(sumS, model_prior_params, k * N)

        # The moment update uses the unscaled statistics and N.
        X_second_moment, X_fourth_moment, Cov_xx_xx, data_params = update_moments(
            data_prior_params, X_second_moment, X_fourth_moment, Cov_xx_xx, S, N, gibbs_flavor)

        if iteration >= num_burnin:
            theta[iteration - num_burnin, :] = model_params[0].flatten()
            sigma_squared[iteration - num_burnin] = model_params[1]

    return theta, sigma_squared

def Gibbs_share(data_prior_params,
          model_prior_params,
          multiN,
          sigma_DP_noise,
          multiZ,
          num_burnin,
          num_iterations,
          gibbs_flavor,
          X_fourth_moment=None, k=1):
    """Multi-party Gibbs sampler in which all parties share one set of feature moments.

    On every sweep, each party's sufficient statistics are imputed from the same
    moments. The statistics are then summed over parties. The pooled sum (with the
    total N) refreshes the shared moments. After that, the pooled sum scaled by k
    (with k times the total N) drives the draw of (theta, sigma^2).

    Parameters
    ----------
    data_prior_params, model_prior_params : prior parameters for p(x) and for the model.
    multiN : per-party record counts.
    sigma_DP_noise : per-party noise standard deviations (indexed per party by
        initialize_values for this flavor).
    multiZ : list of per-party noisy sufficient-statistic dicts.
    num_burnin, num_iterations : discarded / kept sweeps.
    gibbs_flavor : must be 'gibbs-Isserlis-ind'.
    X_fourth_moment : ignored; the initial moments come from the initial data parameters.
    k : multiplier for the pooled statistics and total N in the model update.

    Returns
    -------
    theta : ndarray, shape (num_iterations, multiZ[0]['Xy'].shape[0])
    sigma_squared : ndarray, shape (num_iterations,)

    Raises
    ------
    ValueError
        If gibbs_flavor is not 'gibbs-Isserlis-ind'.
    """
    if gibbs_flavor != 'gibbs-Isserlis-ind':
        raise ValueError(f'Unrecognized gibbs flavor for Gibbs_share! ({gibbs_flavor})')

    num_parties = len(multiZ)
    total_N = np.sum(multiN)
    p = multiZ[0]['Xy'].shape[0]  # d + 1

    # theta, sigma^2, \mu_x, \tau^2, \omega^2
    model_params, data_params, per_party_noise = initialize_values(
        data_prior_params, model_prior_params, sigma_DP_noise, multiZ[0], gibbs_flavor)

    Ex2, Ex4, Cov_xx_xx = calc_Isserlis_moments(data_params)

    theta = np.zeros((num_iterations, p))
    sigma_squared = np.zeros(num_iterations)

    for step in range(num_burnin + num_iterations):
        pooled = {'XX': np.zeros([p, p]), 'Xy': np.zeros([p, 1]), 'yy': 0.0, 'X': np.zeros([p, 1])}

        for party in range(num_parties):
            S_party = update_sufficient_statistics(Ex2, Cov_xx_xx, multiZ[party],
                                                   model_prior_params, model_params,
                                                   per_party_noise[party], multiN[party])
            pooled = {key: pooled[key] + val for key, val in S_party.items()}

        # Shared moments are refreshed from the unscaled pooled statistics.
        Ex2, Ex4, Cov_xx_xx, data_params = update_moments(
            data_prior_params, Ex2, Ex4, Cov_xx_xx, pooled, total_N, gibbs_flavor)

        for key in pooled:
            pooled[key] *= k
        model_params = update_model_params(pooled, model_prior_params, k * total_N)

        kept = step - num_burnin
        if kept >= 0:
            theta[kept, :] = model_params[0].flatten()
            sigma_squared[kept] = model_params[1]

    return theta, sigma_squared

def Gibbs_ind(data_prior_params,
          model_prior_params,
          multiN,
          sigma_DP_noise,
          multiZ,
          num_burnin,
          num_iterations,
          gibbs_flavor,
          X_fourth_moment=None, k=1):
    """Multi-party Gibbs sampler in which each party keeps its own feature moments.

    Every party starts from the same moments (computed from the initial data
    parameters and then tiled nP times). Within a sweep, each party i works as follows:
    its sufficient statistics are imputed from its own moments, and then its moments
    are refreshed from its own statistics and multiN[i] only. The model parameters
    (theta, sigma^2) are shared. They are drawn once per sweep from the sum of all
    parties' statistics, scaled by k, with k times the total N.

    Parameters
    ----------
    data_prior_params, model_prior_params : prior parameters for p(x) and for the model.
    multiN : per-party record counts.
    sigma_DP_noise : per-party noise standard deviations (indexed per party by
        initialize_values for this flavor).
    multiZ : list of per-party noisy sufficient-statistic dicts.
    num_burnin, num_iterations : discarded / kept sweeps.
    gibbs_flavor : must be 'gibbs-Isserlis-ind'.
    X_fourth_moment : ignored; it is overwritten by the initial moments.
    k : multiplier for the summed statistics and total N in the model update.

    Returns
    -------
    theta : ndarray, shape (num_iterations, d + 1) with d + 1 = multiZ[0]['Xy'].shape[0]
    sigma_squared : ndarray, shape (num_iterations,)

    Raises
    ------
    ValueError
        If gibbs_flavor is not 'gibbs-Isserlis-ind'.
    """

    if gibbs_flavor != 'gibbs-Isserlis-ind':
        raise ValueError(f'Unrecognized gibbs flavor for Gibbs_ind! ({gibbs_flavor})')

    nP = len(multiZ)
    N = np.sum(multiN)
    d = multiZ[0]['Xy'].shape[0] - 1

    # theta, sigma^2, \mu_x, \tau^2, \omega^2
    model_params, data_params, noise_covariance_ind = initialize_values(data_prior_params, model_prior_params, sigma_DP_noise, multiZ[0], gibbs_flavor)

    X_second_moment, X_fourth_moment, Cov_xx_xx = calc_Isserlis_moments(data_params)

    # independent moment computation for each party: stack one copy per party along a new axis 0
    X_second_moment = np.tile(X_second_moment, [nP, 1, 1])
    X_fourth_moment = np.tile(X_fourth_moment, [nP, 1, 1, 1, 1])
    # NOTE(review): five tile reps give shape (nP, ...) slices of Cov_xx_xx's own shape only if
    # Cov_xx_xx is 4-D; if calc_Cov_xx_xx returns a 2-D matrix this becomes (nP, 1, 1, D, D)
    # and Cov_xx_xx[i] has two extra leading axes -- confirm against calc_Cov_xx_xx.
    Cov_xx_xx = np.tile(Cov_xx_xx, [nP, 1, 1, 1, 1])


    theta = np.zeros((num_iterations, d+1))
    sigma_squared = np.zeros(num_iterations)
    for iteration in range(num_iterations + num_burnin):

        # Fresh zero accumulator each sweep for the statistics summed over parties.
        sumS = {'XX': np.zeros([d+1, d+1]), 'Xy': np.zeros([d+1, 1]), 'yy': 0.0, 'X': np.zeros([d+1, 1])}

        for i in range(nP):
            S = update_sufficient_statistics(X_second_moment[i], Cov_xx_xx[i], multiZ[i],
                                             model_prior_params, model_params, noise_covariance_ind[i], multiN[i])

            # Party i's moments are updated in place from its own statistics only;
            # data_params is overwritten per party and not used afterwards.
            X_second_moment[i], X_fourth_moment[i], Cov_xx_xx[i], data_params = update_moments(data_prior_params,
                                                                                  X_second_moment[i], X_fourth_moment[i],
                                                                                  Cov_xx_xx[i], S, multiN[i], gibbs_flavor)

            # The result keeps only the keys present in S.
            sumS = {key: sumS[key] + val for key, val in S.items()}

        # aggregated sufficient statistics
        # sumS = {'XX': np.zeros([d+1, d+1]), 'Xy': np.zeros([d+1, 1]), 'yy': 0.0, 'X': np.zeros([d+1, 1])}
        # for i in range(nP):
        #     sumS = {key: sumS[key] + val for key, val in multiS[i].items()}

        # Scale the pooled statistics (and total N) by k for the shared model update.
        for key in sumS:
            sumS[key] *= k
        model_params = update_model_params(sumS, model_prior_params, k * N)

        # for i in range(nP):
        #     X_second_moment[i], X_fourth_moment[i], Cov_xx_xx[i], data_params = update_moments(data_prior_params,
        #                                                                           X_second_moment[i], X_fourth_moment[i],
        #                                                                           Cov_xx_xx[i], multiS[i], multiN[i], gibbs_flavor)

        # noise_covariance = update_noise_covariance(S, Z, epsilon_S, sensitivity)

        if iteration >= num_burnin:
            theta[iteration - num_burnin, :] = model_params[0].flatten()
            sigma_squared[iteration - num_burnin] = model_params[1]

    return theta, sigma_squared

def initialize_values(data_prior_params, model_prior_params, sigma_DP_noise, Z, gibbs_flavor):
    """Draw starting parameters and build the diagonal noise covariance.

    Parameters
    ----------
    data_prior_params : unpacked into NIW_rvs for the Isserlis flavors.
    model_prior_params : unpacked into NIG_rvs.
    sigma_DP_noise : noise standard deviation. For 'gibbs-Isserlis-ind' it must be
        an array with one entry per party (its .size and [i] are used).
    Z : dict whose Z['X'].shape[0] (= p) fixes the statistic dimension p**2 + p + 1.
    gibbs_flavor : selects whether data parameters are drawn and the covariance layout.

    Returns
    -------
    model_params : result of NIG_rvs(*model_prior_params).
    data_params : NIW_rvs(*data_prior_params) for the Isserlis flavors, else None.
    noise_covariance : (dim, dim) diagonal matrix with entries sigma_DP_noise**2, or for
        'gibbs-Isserlis-ind' an (nP, dim, dim) stack with one diagonal matrix per party.
    """
    model_params = NIG_rvs(*model_prior_params)

    draws_data_params = gibbs_flavor in ('gibbs-Isserlis', 'gibbs-Isserlis-ind')
    data_params = NIW_rvs(*data_prior_params) if draws_data_params else None

    p = Z['X'].shape[0]
    dim = p ** 2 + p + 1  # = dim(XX^T) + dim(xy) + dim(y)

    if gibbs_flavor != 'gibbs-Isserlis-ind':
        return model_params, data_params, np.diag(np.ones(dim) * (sigma_DP_noise ** 2))

    num_parties = sigma_DP_noise.size
    stacked = np.zeros([num_parties, dim, dim])
    for party in range(num_parties):
        stacked[party] = np.diag(np.ones(dim) * (sigma_DP_noise[party] ** 2))

    return model_params, data_params, stacked


def update_moments(data_prior_params, Ex2, Ex4, Cov_xx_xx, S, N, gibbs_flavor):
    """Refresh the feature moments used when imputing sufficient statistics.

    Parameters
    ----------
    data_prior_params : prior parameters for p(x), passed to NIW_conjugate_update.
    Ex2, Ex4, Cov_xx_xx : current second moment, fourth moment and Cov_xx_xx.
    S : dict of (imputed) sufficient statistics.
    N : number of records summarized by S.
    gibbs_flavor : str
        'gibbs-Isserlis' / 'gibbs-Isserlis-ind': recompute the moments from the data
        parameters returned by NIW_conjugate_update(S, data_prior_params, N).
        'gibbs-noisy' / 'gibbs-exact' / 'gibbs-prior': the moments are fixed and are
        returned unchanged.

    Returns
    -------
    (Ex2, Ex4, Cov_xx_xx, data_params); data_params is None for the fixed-moment flavors.

    Raises
    ------
    ValueError
        If gibbs_flavor is none of the flavors listed above.
    """
    # Explicit membership test: a bare `... or 'gibbs-Isserlis-ind'` is always truthy, which
    # would re-estimate fixed moments and make the branches below unreachable.
    if gibbs_flavor in ('gibbs-Isserlis', 'gibbs-Isserlis-ind'):
        # NOTE: an earlier variant passed project_suff_stats(S) instead of S.
        data_params = NIW_conjugate_update(S, data_prior_params, N)
        Ex2, Ex4, Cov_xx_xx = calc_Isserlis_moments(data_params)
    elif gibbs_flavor in ('gibbs-noisy', 'gibbs-exact', 'gibbs-prior'):
        data_params = None
    else:
        raise ValueError(f'Unrecognized moments source! ({gibbs_flavor})')

    return Ex2, Ex4, Cov_xx_xx, data_params


def calc_Isserlis_moments(data_params):
    """Compute the Gaussian moments of the augmented feature vector [x, 1].

    Parameters
    ----------
    data_params : tuple (mu_x, Tau)
        Mean and covariance of x. Any 0-d mean (Python/NumPy float or int,
        0-d array) is treated as the one-dimensional case and promoted to 1x1 arrays.

    Returns
    -------
    Ex2 : second-moment matrix Ex4[:, :, -1, -1]; replaced by nearestPD(Ex2) when
        isPD reports it is not positive definite.
    Ex4 : fourth-moment tensor from calc_X_fourth_moment.
    Cov_xx_xx : calc_Cov_xx_xx(Ex2, Ex4), computed before the PD projection of Ex2.
    """
    mu_x, Tau = data_params

    # np.ndim also catches ints, np.float32 and 0-d arrays, which
    # isinstance(mu_x, float) misses; those would otherwise crash on len().
    if np.ndim(mu_x) == 0:
        mu_x = np.reshape(mu_x, (1, 1))
        Tau = np.reshape(Tau, (1, 1))

    Ex4 = calc_X_fourth_moment(mu_x, Tau)

    # The last coordinate of the augmented vector is the constant 1, so fixing the
    # last two slots to it leaves the second moments.
    Ex2 = Ex4[:, :, -1, -1]

    Cov_xx_xx = calc_Cov_xx_xx(Ex2, Ex4)

    if not isPD(Ex2):
        Ex2 = nearestPD(Ex2)

    return Ex2, Ex4, Cov_xx_xx

def calc_X_fourth_moment(mu_x, Tau):
    """Fourth-moment tensor E[z_i z_j z_k z_l] of z = [x, 1] for x ~ N(mu_x, Tau).

    Uses Isserlis' theorem for a Gaussian with mean m and covariance T:

        E[z_i z_j z_k z_l] = m_i m_j m_k m_l
                             + (six terms m_a m_b T_cd over the pairings)
                             + T_ij T_kl + T_ik T_jl + T_il T_jk

    The constant last coordinate of z has mean 1 and zero (co)variance.

    Parameters
    ----------
    mu_x : mean of x -- a column vector (p, 1), a flat vector (p,), or a scalar.
    Tau : covariance of x, shape (p, p) (a scalar is broadcast when p == 1).

    Returns
    -------
    ndarray of shape (p + 1, p + 1, p + 1, p + 1).
    """
    # Augmented mean [mu_x, 1]; ravel accepts column, flat, or scalar means alike.
    mu1 = np.append(np.asarray(mu_x, dtype=float).ravel(), 1.0)
    d = mu1.size

    # Augmented covariance: the constant coordinate has zero (co)variance.
    Tau1 = np.zeros([d, d])
    Tau1[:-1, :-1] = Tau

    Ex4 = np.einsum('i,j,k,l->ijkl', mu1, mu1, mu1, mu1)
    for einstr in ('i,j,kl', 'i,k,jl', 'i,l,jk', 'j,k,il', 'j,l,ik', 'k,l,ij'):
        Ex4 += np.einsum(einstr + '->ijkl', mu1, mu1, Tau1)
    for einstr in ('ij,kl', 'ik,jl', 'il,jk'):
        Ex4 += np.einsum(einstr + '->ijkl', Tau1, Tau1)

    return Ex4

def calc_Isserlis_X_fourth_moment(mu_x, Tau):
    """Fourth-moment tensor of z = [x, 1], assuming x is zero-mean Gaussian with covariance Tau.

    mu_x is used only for its length; the zero-mean assumption is not checked.
    The filled entries are:
      * all four slots in x:         Tau_ij Tau_kl + Tau_ik Tau_jl + Tau_il Tau_jk
      * exactly two slots in x:      Tau_ij (for each placement of the two constant slots)
      * all four slots constant:     1
    Every other entry (one or three slots in x) stays 0.

    Returns
    -------
    ndarray of shape (len(mu_x) + 1,) * 4.
    """
    m = len(mu_x)
    Ex4 = np.zeros([m + 1] * 4)

    if m > 0:
        T = Tau[:m, :m]

        Ex4[:m, :m, :m, :m] = (np.einsum('ij,kl->ijkl', T, T)
                               + np.einsum('ik,jl->ijkl', T, T)
                               + np.einsum('il,jk->ijkl', T, T))

        # Two slots in x, two constant slots: the six placements of the constants.
        Ex4[:m, :m, -1, -1] = T
        Ex4[:m, -1, :m, -1] = T
        Ex4[:m, -1, -1, :m] = T
        Ex4[-1, :m, :m, -1] = T
        Ex4[-1, :m, -1, :m] = T
        Ex4[-1, -1, :m, :m] = T

    Ex4[-1, -1, -1, -1] = 1.

    return Ex4

def calc_hier_norm_moments(data_params):
    """Compute the moments of the augmented feature vector [x, 1] from calc_hier_norm_X_fourth_moment.

    NOTE(review): calc_hier_norm_X_fourth_moment is not defined in the visible part of
    this module. Confirm it exists elsewhere; otherwise this function raises NameError.

    Parameters
    ----------
    data_params : tuple (mu_x, Tau)
        Mean and covariance parameters of x. Any 0-d mean (Python/NumPy float or int,
        0-d array) is treated as the one-dimensional case and promoted to 1x1 arrays.

    Returns
    -------
    Ex2 : Ex4[:, :, -1, -1]; replaced by nearestPD(Ex2) when isPD reports it is not PD.
    Ex4 : tensor returned by calc_hier_norm_X_fourth_moment.
    Cov_xx_xx : calc_Cov_xx_xx(Ex2, Ex4), computed before the PD projection of Ex2.
    """
    mu_x, Tau = data_params

    # np.ndim also catches ints, np.float32 and 0-d arrays, which
    # isinstance(mu_x, float) misses.
    if np.ndim(mu_x) == 0:
        mu_x = np.reshape(mu_x, (1, 1))
        Tau = np.reshape(Tau, (1, 1))

    Ex4 = calc_hier_norm_X_fourth_moment(mu_x, Tau)
    Ex2 = Ex4[:, :, -1, -1]

    Cov_xx_xx = calc_Cov_xx_xx(Ex2, Ex4)

    if not isPD(Ex2):
        Ex2 = nearestPD(Ex2)

    return Ex2, Ex4, Cov_xx_xx


def update_noise_covariance(S, Z, epsilon_S, sensitivity):
    """Draw a diagonal noise covariance from the residuals between Z and S.

    The residuals are the flattened differences Z - S for 'XX' and 'Xy', followed by
    the 'yy' difference. For each coordinate, an inverse variance is drawn from
    np.random.wald with mean 1 / (lam * |residual|) and scale 1 / lam**2, where
    lam = sensitivity / epsilon_S. The reciprocals form the diagonal of the result.
    Presumably this is the latent-variance step of a scale-mixture representation
    of Laplace noise -- confirm against the model derivation.

    Returns
    -------
    ndarray, square diagonal matrix with one entry per residual coordinate.
    """
    residuals = [Z[key].flatten() - S[key].flatten() for key in ('XX', 'Xy')]
    residuals.append(Z['yy'] - S['yy'])
    magnitude = np.abs(np.hstack(residuals))

    scale = sensitivity / epsilon_S
    precision = np.random.wald(1 / (scale * magnitude), 1 / scale ** 2)

    return np.diag(np.ravel(1 / precision))


def update_model_params(S, model_prior_params, N):
    """Draw (theta, sigma^2) given sufficient statistics S for N records.

    sigma^2 is drawn from scipy.stats.invgamma(a=a_n, scale=b_n). theta is then
    drawn with fast_sample_multivariate_normal, with mean mu_n and covariance
    sigma^2 * inv_lambda_n. All posterior quantities come from calc_posterior_params.

    Returns
    -------
    theta : ndarray with a trailing singleton axis (a column vector when the draw is 1-D).
    sigma_squared : the inverse-gamma draw.
    """
    # The second value (lambda_n, presumably the posterior precision) is not needed:
    # its inverse is returned directly.
    posterior_mean, _, posterior_inv_precision, shape_a, scale_b = calc_posterior_params(
        S, N, model_prior_params)

    sigma_squared = scipy.stats.invgamma.rvs(a=shape_a, scale=scale_b)

    draw = fast_sample_multivariate_normal(posterior_mean.ravel(),
                                           sigma_squared * posterior_inv_precision)

    # A one-dimensional draw may come back as a bare float; wrap it so it can be reshaped.
    if isinstance(draw, float):
        draw = np.array([draw])

    return np.expand_dims(draw, axis=1), sigma_squared
