import joblib
from joblib import Parallel, delayed
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neural_network import MLPRegressor
from tqdm.auto import tqdm


class KernelRegressor(object):
  """Ridge regression on row-normalized Gaussian kernel features.

  The kernel centers are the k-means cluster centers of the training inputs.
  The kernel width is the square root of the mean squared distance from each
  center to its nearest other center.

  Attributes set by ``fit``:
    centers: array of cluster centers, one row per cluster.
    dY2: squared Euclidean norm of each center (cached for ``predict``).
    kernel_width: scalar width of the Gaussian kernel.
    Theta: regression weights mapping kernel features to targets.
  """

  def __init__(self, num_clusters=10, reg=1e-6):
    """Store the number of kernel centers and the ridge penalty ``reg``."""
    self.num_clusters = num_clusters
    self.reg = reg

  def _kernel_features(self, features):
    """Return row-normalized Gaussian kernel activations of ``features``.

    Reads ``self.centers``, ``self.dY2`` and ``self.kernel_width``. The
    result has one row per input and one column per center; each row sums
    to one.
    """
    sq_norms = np.square(features).sum(axis=-1)
    cross = np.einsum("ik,jk->ij", features, self.centers)
    d2 = sq_norms[:, np.newaxis] - 2 * cross + self.dY2[np.newaxis, :]
    # Shifting by the row minimum leaves the normalized features unchanged
    # but keeps the largest exponent at zero, so no row underflows to all zeros.
    d2 -= d2.min(axis=-1)[:, np.newaxis]

    X = np.exp(- d2 / (2 * np.square(self.kernel_width)))
    X /= X.sum(axis=-1)[:, np.newaxis]
    return X

  def fit(self, features, Y):
    """Fit centers, kernel width and ridge weights.

    Args:
      features: 2-D array of training inputs, one row per example.
      Y: training targets with one leading entry per example.
    """
    kmeans = KMeans(self.num_clusters)
    self.centers = kmeans.fit(features).cluster_centers_
    self.dY2 = np.square(self.centers).sum(axis=-1)

    if self.centers.shape[0] > 1:
      # pairwise squared distances between centers; after sorting each row,
      # column 0 is the (zero) self-distance and column 1 the nearest neighbor
      center_cross = np.einsum("ik,jk->ij", self.centers, self.centers)
      center_d2 = self.dY2[:, np.newaxis] - 2 * center_cross + self.dY2[np.newaxis, :]
      self.kernel_width = np.sqrt(np.sort(center_d2, axis=-1)[:, 1].mean())
    else:
      # A single center has no neighbor. The normalized feature is then the
      # constant 1 for any positive width, so the value chosen is immaterial.
      self.kernel_width = 1.0

    X = self._kernel_features(features)
    gram = X.T.dot(X) + self.reg * np.eye(X.shape[1])
    # solve the regularized normal equations directly instead of forming an inverse
    self.Theta = np.linalg.solve(gram, X.T.dot(Y))

  def predict(self, features):
    """Return predictions for ``features`` (2-D array, one row per example)."""
    return self._kernel_features(features).dot(self.Theta)


class DiffusionPrior(object):
  """Diffusion-model prior over ``d``-dimensional vectors.

  The forward process uses a constant per-stage ``alpha`` over ``T`` stages.
  The reverse process is parameterized by one noise-predicting MLP per stage
  (Ho et al., 2020, Denoising Diffusion Probabilistic Models). ``train`` must
  be called before any of the sampling methods, since it creates
  ``self.regressors``.

  Arrays indexed by stage have length ``T + 1``; index 0 is the clean stage.
  """

  _CLIP = 100  # samples are clipped to [-_CLIP, _CLIP] for numerical stability
  _MIN_VAR = 1e-6  # floor on the reverse-process variance, which is zero in stage 1

  def __init__(self, d, T, alpha, reg=1e-6, tol=1e-4, hidden_size=None):
    """Set up the noise schedule and the regressor hyper-parameters.

    Args:
      d: dimension of the modeled vectors.
      T: number of diffusion stages.
      alpha: per-stage forward-process alpha (stage 0 is fixed to 1).
      reg: L2 penalty passed to each ``MLPRegressor`` as ``alpha``.
      tol: optimization tolerance passed to each ``MLPRegressor``.
      hidden_size: width of the two hidden layers; defaults to ``10 * d``.
    """
    self.d = d
    self.T = T
    self.alpha = alpha * np.ones(self.T + 1)
    self.alpha[0] = 1.0

    self.beta = 1 - self.alpha
    self.alpha_bar = np.cumprod(self.alpha)
    self.beta_tilde = np.zeros(self.T + 1)
    self.beta_tilde[1 :] = (1 - self.alpha_bar[: self.T]) * self.beta[1 :] / (1 - self.alpha_bar[1 :])

    # reverse process parameterization
    self.reg = reg  # least-squares regularization
    self.tol = tol  # fitting stopping tolerance
    if hidden_size is None:
      self.hidden_size = 10 * self.d
    else:
      self.hidden_size = hidden_size

  def _clip(self, values):
    """Clip ``values`` elementwise to ``[-_CLIP, _CLIP]``."""
    return np.clip(values, -self._CLIP, self._CLIP)

  def _reverse_variance(self, t):
    """Return the stage-``t`` reverse-process variance, floored at ``_MIN_VAR``."""
    return np.maximum(self.beta_tilde[t], self._MIN_VAR)

  def _predict_noise(self, S, t):
    """Predict the noise in ``S`` (shape ``(n, d)``) with the stage-``t`` regressor.

    Single-output regressors return a 1-D array. Reshaping to ``S``'s shape
    keeps the ``d == 1`` case from broadcasting ``(n,)`` against ``(n, 1)``.
    """
    return np.reshape(self.regressors[t].predict(S), np.shape(S))

  def train_stage(self, t, St, epsilon):
    """Fit the noise regressor of one stage.

    Args:
      t: stage index (not used in the fit itself).
      St: diffused samples of the stage, shape ``(n, d)``.
      epsilon: noise that produced ``St``, same shape as ``St``.

    Returns:
      ``(regressor, error)`` where ``error`` is the root mean squared norm of
      the training residual.
    """
    # Equation 12 in Ho et al. (2020)
    # Denoising Diffusion Probabilistic Models
    regressor = MLPRegressor(hidden_layer_sizes=(self.hidden_size, self.hidden_size),
      alpha=self.reg, early_stopping=False, verbose=False, tol=self.tol)
    epsilon = np.asarray(epsilon)
    # a single-column target is passed as 1-D so it is fitted as one output
    single_output = epsilon.ndim == 2 and epsilon.shape[1] == 1
    regressor.fit(St, epsilon[:, 0] if single_output else epsilon)
    residual = np.reshape(regressor.predict(St), epsilon.shape) - epsilon
    error = np.sqrt(np.square(residual).sum(axis=-1).mean())
    return regressor, error

  def train(self, S0):
    """Learn the reverse process from clean samples ``S0`` of shape ``(n, d)``.

    Diffuses ``S0`` to every stage, fits one regressor per stage in parallel
    and stores them in ``self.regressors`` (index 0 is ``None``).

    Returns:
      Array of length ``T`` with the training error of stages ``1..T``.
    """
    n = S0.shape[0]

    # diffusion using the forward process
    epsilon = np.random.randn(self.T + 1, n, self.d)
    S = np.zeros((self.T + 1, n, self.d))
    S[0, :, :] = S0
    for t in range(1, self.T + 1):
      S[t, :, :] = np.sqrt(self.alpha_bar[t]) * S0 + np.sqrt(1 - self.alpha_bar[t]) * epsilon[t, :, :]

    # reverse process learning
    output = Parallel(n_jobs=-1)(
      delayed(self.train_stage)(t, S[t, :, :], epsilon[t, :, :]) for t in tqdm(range(1, self.T + 1)))

    # output[k] belongs to stage k + 1; stage 0 has no regressor
    self.regressors = [None] + [regressor for regressor, _ in output]
    errors = np.array([error for _, error in output], dtype=float)

    return errors

  def conditional_prior_mean(self, S, t):
    """Return the mean of stage ``t - 1`` given stage-``t`` samples ``S`` of shape ``(n, d)``."""
    # Algorithm 2 in Ho et al. (2020)
    # Denoising Diffusion Probabilistic Models
    epsilon = self._predict_noise(S, t)
    S0 = (S - np.sqrt(1 - self.alpha_bar[t]) * epsilon) / np.sqrt(self.alpha_bar[t])
    w0 = np.sqrt(self.alpha_bar[t - 1]) * self.beta[t] / (1 - self.alpha_bar[t])
    wt = np.sqrt(self.alpha[t]) * (1 - self.alpha_bar[t - 1]) / (1 - self.alpha_bar[t])
    mu = w0 * S0 + wt * S
    return mu

  def sample(self, n):
    """Draw ``n`` prior trajectories.

    Returns:
      Array of shape ``(T + 1, n, d)``; index 0 holds the final samples.
    """
    # reverse process sampling
    S = np.zeros((self.T + 1, n, self.d))
    S[self.T, :, :] = np.random.randn(n, self.d)
    for t in range(self.T, 0, -1):
      mu = self.conditional_prior_mean(S[t, :, :], t)
      noise = np.random.randn(n, self.d)
      S[t - 1, :, :] = self._clip(mu + np.sqrt(self.beta_tilde[t]) * noise)

    return S

  def posterior_sample(self, theta_bar, Sigma_bar):
    """Draw one reverse-process trajectory conditioned on Gaussian evidence.

    At every stage the evidence ``(theta_bar, Sigma_bar)`` is scaled by the
    stage's ``alpha_bar`` and combined with the stage prior by adding
    precisions.

    Args:
      theta_bar: evidence mean, shape ``(d,)``.
      Sigma_bar: evidence covariance, shape ``(d, d)``; must be invertible.

    Returns:
      Array of shape ``(T + 1, d)``; index 0 holds the final sample.
    """
    # reverse process sampling with evidence
    S = np.zeros((self.T + 1, self.d))
    identity = np.eye(self.d)
    for t in range(self.T + 1, 0, -1):
      # diffused evidence
      theta_diff = np.sqrt(self.alpha_bar[t - 1]) * theta_bar
      Sigma_diff = self.alpha_bar[t - 1] * Sigma_bar
      Lambda_diff = np.linalg.inv(Sigma_diff)

      # stage prior: standard normal for the initial draw, learned reverse step otherwise
      if t == self.T + 1:
        mu = np.zeros(self.d)
        Lambda = identity
      else:
        # index the single row instead of squeezing so that d == 1 stays 1-D
        mu = self.conditional_prior_mean(S[[t], :], t)[0]
        Lambda = identity / self._reverse_variance(t)

      # posterior distribution
      Sigma_hat = np.linalg.inv(Lambda + Lambda_diff)
      mu_hat = Sigma_hat.dot(Lambda.dot(mu) + Lambda_diff.dot(theta_diff))

      # posterior sampling
      S[t - 1, :] = self._clip(np.random.multivariate_normal(mu_hat, Sigma_hat))

    return S

  def posterior_sample_map(self, map_lambda):
    """Draw one trajectory using a caller-supplied Gaussian posterior update.

    Args:
      map_lambda: callable ``(mu0, Sigma0) -> (mu_hat, Sigma_hat)``. It receives
        the stage prior mean and covariance divided by ``sqrt(alpha_bar)`` and
        ``alpha_bar`` respectively; its outputs are multiplied back by the
        same factors before sampling. The returned arrays are not modified.

    Returns:
      Array of shape ``(T + 1, d)``; index 0 holds the final sample.
    """
    # reverse process sampling with evidence
    S = np.zeros((self.T + 1, self.d))
    for t in range(self.T + 1, 0, -1):
      scale = self.alpha_bar[t - 1]

      # posterior distribution
      if t == self.T + 1:
        mu0 = np.zeros(self.d)
        Sigma0 = np.eye(self.d) / scale
      else:
        # index the single row instead of squeezing so that d == 1 stays 1-D
        mu = self.conditional_prior_mean(S[[t], :], t)[0]
        Sigma = self._reverse_variance(t) * np.eye(self.d)
        mu0 = mu / np.sqrt(scale)
        Sigma0 = Sigma / scale

      mu_hat, Sigma_hat = map_lambda(mu0, Sigma0)
      # Rescale out of place: the callback may return arrays it still owns
      # (or integer arrays, which cannot be scaled in place by a float).
      mu_hat = np.sqrt(scale) * np.asarray(mu_hat, dtype=float)
      Sigma_hat = scale * np.asarray(Sigma_hat, dtype=float)

      # posterior sampling
      S[t - 1, :] = self._clip(np.random.multivariate_normal(mu_hat, Sigma_hat))

    return S

  def posterior_sample_grad(self, loglik_grad):
    """Draw one trajectory, shifting every stage by a caller-supplied gradient term.

    Args:
      loglik_grad: callable taking the current clean-sample estimate ``s0``
        (shape ``(d,)``) and returning a vector that is added to the stage
        sample.

    Returns:
      Array of shape ``(T + 1, d)``; index 0 holds the final sample.
    """
    # reverse process sampling with evidence
    S = np.zeros((self.T + 1, self.d))
    for t in range(self.T + 1, 0, -1):
      # posterior distribution
      if t == self.T + 1:
        s0 = np.zeros(self.d)
        mu = np.zeros(self.d)
        Sigma = np.eye(self.d)
      else:
        # epsilon to score conversion based on (29) in Chung et al. (2023)
        # Diffusion Posterior Sampling for General Noisy Inverse Problems
        epsilon = self._predict_noise(S[[t], :], t)[0]
        score = - epsilon / np.sqrt(1 - self.alpha_bar[t])
        s0 = (S[t, :] + (1 - self.alpha_bar[t]) * score) / np.sqrt(self.alpha_bar[t])
        w0 = np.sqrt(self.alpha_bar[t - 1]) * self.beta[t] / (1 - self.alpha_bar[t])
        wt = np.sqrt(self.alpha[t]) * (1 - self.alpha_bar[t - 1]) / (1 - self.alpha_bar[t])
        mu = w0 * s0 + wt * S[t, :]
        Sigma = self._reverse_variance(t) * np.eye(self.d)

      # posterior sampling
      S[t - 1, :] = self._clip(np.random.multivariate_normal(mu, Sigma) + loglik_grad(s0))

    return S
