
import torch
import logging
import numpy as np
from scipy import integrate
import os
import math
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm
from torch.distributions import Normal
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import sympy
import random
import logging
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from scipy.optimize import minimize
from sympy import *
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, TensorDataset
from libs.iddpm import UNetModel,UNetModel4Pretrained,UNetModel4Pretrained2
from sklearn.model_selection import train_test_split
import copy
import json
import warnings
from absl import app, flags

import torch
from tensorboardX import SummaryWriter
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid, save_image
from torchvision import transforms
from tqdm import trange
from adan import Adan
# Command-line configuration (absl flags) shared by training, sampling and
# evaluation code in this module.
FLAGS = flags.FLAGS
flags.DEFINE_bool('train', False, help='train from scratch')
flags.DEFINE_bool('eval', False, help='load ckpt.pt and evaluate FID and IS')
# UNet: IDDPM
flags.DEFINE_integer('in_channel', 3, help='input channel of UNet')
flags.DEFINE_integer('out_channel', 3, help='output channel of UNet')
flags.DEFINE_integer('ch', 128, help='base channel of UNet')
flags.DEFINE_integer('num_res_blocks', 3, help='# resblock in each level')
flags.DEFINE_integer('num_heads', 4, help='Multi-Heads for attention')
flags.DEFINE_integer('dims', 2, help='1,2,3 dims')
flags.DEFINE_multi_integer('ch_mult', [1, 2, 2, 2], help='channel multiplier')
flags.DEFINE_multi_integer('attn', [32 // 16, 32 // 8], help='add attention to these levels')
flags.DEFINE_float('dropout', 0.3, help='dropout rate of resblock')
# The help text here used to be a copy of the --eval description.
flags.DEFINE_bool('use_scale_shift_norm', True, help='use scale-shift normalization in the UNet resblocks')

flags.DEFINE_integer('section_begin', 0, help='section to record L')
flags.DEFINE_integer('section_end', 1001, help='section to record L (END)')

flags.DEFINE_integer('head_out_channels', 3, help='the final layer of High order noise network')
flags.DEFINE_enum('mode', 'simple', ['simple','complex'], help='the mode for honn modeling')

# Gaussian Diffusion
flags.DEFINE_float('beta_1', 1e-4, help='start beta value')
flags.DEFINE_float('beta_T', 0.02, help='end beta value')
flags.DEFINE_integer('T', 1000, help='total diffusion training noising steps')
flags.DEFINE_enum('sample_type', 'ddpm', ['ddpm', 'analyticdpm', 'gmddpm','mean_network'], help='sample type for sampling')
flags.DEFINE_enum('mean_type', 'epsilon', ['xprev', 'xstart', 'epsilon'], help='predict variable')
flags.DEFINE_enum('var_type', 'fixedlarge', ['fixedlarge', 'fixedsmall'], help='variance type')
# The help text here used to be a copy of the --mode description.
flags.DEFINE_enum('noise_schedule', 'linear', ['linear','cosine'], help='noise (beta) schedule of the diffusion process')
# Training
flags.DEFINE_float('lr', 1e-4, help='target learning rate')
flags.DEFINE_float('grad_clip', 1., help="gradient norm clipping")
flags.DEFINE_integer('total_steps', 500001, help='total training steps')
flags.DEFINE_integer('img_size', 32, help='image size')
flags.DEFINE_integer('warmup', 5000, help='learning rate warmup')
flags.DEFINE_integer('batch_size', 128, help='batch size')
flags.DEFINE_integer('num_workers', 4, help='workers of Dataloader')
flags.DEFINE_integer('noise_order', 1, help="the order of noise used to training")
flags.DEFINE_float('ema_decay', 0.9999, help="ema decay rate")
flags.DEFINE_bool('parallel', False, help='multi gpu training')
# The default is a checkpoint file, not a log directory as the old help said.
flags.DEFINE_string('pretrained_dir', './logs/iDDPM_CIFAR10_EPS/models/ckpt_1_450000.pt', help='path of the pretrained checkpoint file')

# Logging & Sampling
flags.DEFINE_string('logdir', './logs/iDDPM_CIFAR10_EPS', help='log directory')
flags.DEFINE_integer('sample_size', 64, "sampling size of images")
flags.DEFINE_integer('sample_step', 10000, help='frequency of sampling')
flags.DEFINE_integer('sample_steps', 1000, help='Sampling steps for generation stage')
flags.DEFINE_bool('covmean', False,help='whether use cov mean to sample')
# Evaluation
flags.DEFINE_integer('save_step', 50000, help='frequency of saving checkpoints, 0 to disable during training')
flags.DEFINE_integer('eval_step', 0, help='frequency of evaluating model, 0 to disable during training')
flags.DEFINE_integer('num_images', 50000, help='the number of generated images for evaluation')
flags.DEFINE_bool('fid_use_torch', False, help='calculate IS and FID on gpu')
flags.DEFINE_bool('time_shift', False, help='whether the noised data is from t=1')
flags.DEFINE_string('fid_cache', './stats/cifar10.train.npz', help='FID cache')
# Model Dir
flags.DEFINE_string('eps1_dir', './logs/iDDPM_CIFAR10_EPS/models/ckpt_1_300000.pt', help='eps1 model log directory')
flags.DEFINE_string('eps2_dir', './logs/iDDPM_CIFAR10_EPS2/models/ckpt_2_300000.pt', help='eps2 model log directory')
flags.DEFINE_string('eps3_dir', './logs/iDDPM_CIFAR10_complex_EPS3/models/ckpt_3_300000.pt', help='eps3 model log directory')
flags.DEFINE_string('eps4_dir', './logs/iDDPM_CIFAR10_complex_EPS4/models/ckpt_4_300000.pt', help='eps4 model log directory')

# Module-wide default device; hard-wired to the first CUDA GPU.
device = torch.device('cuda:0')

import numpy as np
import torch.nn.functional as F
import torch
import math


def bipartition(ts):
    """Split a 4-D tensor into two equal halves along dim 1.

    Args:
        ts: tensor with exactly four dimensions whose size along dim 1
            is even (the original comment describes it as bs * 2c * w * w).

    Returns:
        The tuple produced by ``ts.split(c, dim=1)`` where ``c`` is half of
        the dim-1 size.

    Raises:
        NotImplementedError: if ``ts`` does not have four dimensions.
    """
    if ts.dim() != 4:
        raise NotImplementedError
    n_channels = ts.size(1)
    assert n_channels % 2 == 0
    return ts.split(n_channels // 2, dim=1)


def stp(s, ts: torch.Tensor):  # scalar tensor product
    """Multiply ``ts`` by per-row scalars ``s`` broadcast over trailing dims.

    Args:
        s: tensor or NumPy array of scalars; a NumPy array is converted to a
            tensor with the type of ``ts``.
        ts: tensor to be scaled.

    Returns:
        ``s`` reshaped to ``(-1, 1, ..., 1)`` (as many dims as ``ts``)
        multiplied element-wise with ``ts``.
    """
    coeff = torch.from_numpy(s).type_as(ts) if isinstance(s, np.ndarray) else s
    broadcast_shape = (-1,) + (1,) * (ts.dim() - 1)
    return coeff.view(*broadcast_shape) * ts


def sos(a, start_dim=1):  # sum of square
    """Return the sum of squared entries over the dims from ``start_dim`` on.

    The squared tensor is flattened from ``start_dim`` and summed along its
    last axis.
    """
    squared = a.pow(2)
    flat = squared.flatten(start_dim=start_dim)
    return flat.sum(dim=-1)


def mos(a, start_dim=1):  # mean of square
    """Return the mean of squared entries over the dims from ``start_dim`` on.

    The squared tensor is flattened from ``start_dim`` and averaged along its
    last axis.
    """
    squared = a.pow(2)
    flat = squared.flatten(start_dim=start_dim)
    return flat.mean(dim=-1)


def inner_product(a, b, start_dim=1):
    """Return the inner product of ``a`` and ``b`` over dims from ``start_dim``.

    The element-wise product is flattened from ``start_dim`` and summed along
    its last axis.
    """
    product = a * b
    flat = product.flatten(start_dim=start_dim)
    return flat.sum(dim=-1)


def mean_flat(tensor, keepdim=False):
    """Average ``tensor`` over every dimension except the first.

    Args:
        tensor: input tensor.
        keepdim: forwarded to ``Tensor.mean``; keeps the reduced dims as size 1.
    """
    reduce_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=reduce_dims, keepdim=keepdim)


def duplicate(tensor, *size):
    """Expand ``tensor`` with new leading dims of the given ``size``.

    A leading axis is inserted and the result is expanded (no copy) to shape
    ``(*size, *tensor.shape)``.
    """
    target_shape = (*size, *tensor.shape)
    with_leading_axis = tensor.unsqueeze(dim=0)
    return with_leading_axis.expand(*target_shape)


def unsqueeze_like(tensor, template, start="left"):
    """Pad ``tensor`` with singleton dims so it has as many dims as ``template``.

    Args:
        tensor: tensor whose shape must equal the leading (``start="left"``)
            or trailing (``start="right"``) dims of ``template``.
        template: tensor supplying the target number of dimensions.
        start: ``"left"`` keeps ``tensor``'s dims first and appends ones;
            ``"right"`` prepends ones and keeps ``tensor``'s dims last.

    Returns:
        A view of ``tensor`` with the singleton dims added.

    Raises:
        ValueError: if ``start`` is neither ``"left"`` nor ``"right"``.
    """
    if start not in ("left", "right"):
        raise ValueError

    own_dims = tensor.dim()
    ones = [1] * (template.dim() - own_dims)

    if start == "left":
        assert tensor.shape == template.shape[:own_dims]
        return tensor.view(*tensor.shape, *ones)

    assert tensor.shape == template.shape[-own_dims:]
    return tensor.view(*ones, *tensor.shape)


def logsumexp(tensor, dim, keepdim=False):
    """Compute log(sum(exp(tensor))) along ``dim`` using the max-shift trick.

    Hand-rolled on purpose: the original author's note says the logsumexp of
    pytorch is not stable.

    Args:
        tensor: input tensor.
        dim: dimension to reduce.
        keepdim: when False the reduced dimension is squeezed out.
    """
    peak = tensor.max(dim=dim, keepdim=True)[0]
    shifted_sum = (tensor - peak).exp().sum(dim=dim, keepdim=True)
    result = shifted_sum.log() + peak
    if keepdim:
        return result
    return result.squeeze(dim=dim)


def log(x):
    """Natural logarithm that dispatches on the type of ``x``.

    Tensors use ``Tensor.log``, NumPy arrays use ``np.log`` and anything else
    is handed to ``math.log``.
    """
    if isinstance(x, torch.Tensor):
        return x.log()
    if isinstance(x, np.ndarray):
        return np.log(x)
    return math.log(x)


def log_normal(x, mu, var):  # element-wise
    """Element-wise log-density of a normal with mean ``mu`` and variance ``var``.

    The normalising term goes through the module's type-dispatching ``log``
    helper, so ``var`` may be a tensor, a NumPy array or a plain number.
    """
    inv_var = 1. / var
    quadratic = -0.5 * (x - mu) ** 2 * inv_var
    log_normaliser = 0.5 * log(2 * np.pi * var)
    return quadratic - log_normaliser


def approx_standard_normal_cdf(x):
    """
    Fast tanh-based approximation of the standard normal cumulative
    distribution function, evaluated element-wise on tensor ``x``.

    NOTE: an identical definition appears again further down in this module;
    the later one is the binding that remains after import.
    """
    cubic_term = x + 0.044715 * torch.pow(x, 3)
    return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * cubic_term))


def log_discretized_normal(x, mu, var):  # element-wise
    """Element-wise log-probability of ``x`` under a discretized normal.

    The probability mass is the (approximate) normal CDF difference over the
    bin ``[x - 1/255, x + 1/255]``. Values below -0.999 take the whole lower
    tail and values above 0.999 the whole upper tail. Every probability is
    clamped to at least 1e-12 before the log.

    NOTE: an identical definition appears again further down in this module.
    """
    std = var ** 0.5
    offset = x - mu
    cdf_hi = approx_standard_normal_cdf((offset + 1. / 255) / std)
    cdf_lo = approx_standard_normal_cdf((offset - 1. / 255) / std)

    log_lower_tail = cdf_hi.clamp(min=1e-12).log()
    log_upper_tail = (1. - cdf_lo).clamp(min=1e-12).log()
    log_bin_mass = (cdf_hi - cdf_lo).clamp(min=1e-12).log()

    interior_or_upper = torch.where(x > 0.999, log_upper_tail, log_bin_mass)
    return torch.where(x < -0.999, log_lower_tail, interior_or_upper)


def binary_cross_entropy_with_logits(logits, inputs):
    r""" Element-wise -inputs * log (sigmoid(logits)) - (1 - inputs) * log (1 - sigmoid(logits)).

    Whichever of the two arguments has fewer dimensions is expanded to the
    shape of the other before the unreduced loss is computed (when both have
    the same number of dimensions, ``logits`` is the one expanded).
    """
    targets_are_smaller = inputs.dim() < logits.dim()
    if targets_are_smaller:
        inputs = inputs.expand_as(logits)
    else:
        logits = logits.expand_as(inputs)
    return F.binary_cross_entropy_with_logits(logits, inputs, reduction="none")


def log_bernoulli(inputs, logits, n_data_dim):
    """Bernoulli log-likelihood of ``inputs`` given ``logits``.

    The element-wise negative cross entropy is flattened over the last
    ``n_data_dim`` dimensions and summed along the resulting final axis.
    """
    elementwise_nll = binary_cross_entropy_with_logits(logits, inputs)
    flattened = (-elementwise_nll).flatten(-n_data_dim)
    return flattened.sum(dim=-1)


def kl_between_normal(mu_0, var_0, mu_1, var_1):  # element-wise
    """Element-wise KL( N(mu_0, var_0) || N(mu_1, var_1) ).

    At least one argument must be a tensor. Variances given as plain numbers
    are converted to tensors matching that reference tensor. Means may be
    tensors or plain numbers in any combination.

    Args:
        mu_0, var_0: mean and variance of the first normal.
        mu_1, var_1: mean and variance of the second normal.

    Returns:
        Tensor holding the element-wise KL divergence.

    Raises:
        AssertionError: if none of the four arguments is a tensor.
    """
    reference = next(
        (obj for obj in (mu_0, var_0, mu_1, var_1) if isinstance(obj, torch.Tensor)),
        None,
    )
    assert reference is not None

    var_0, var_1 = [
        v if isinstance(v, torch.Tensor) else torch.tensor(v).to(reference)
        for v in (var_0, var_1)
    ]

    # ``** 2`` rather than ``.pow(2)``: the difference is a plain Python number
    # when both means are scalars, and that has no ``.pow`` method.
    mean_gap_sq = (mu_0 - mu_1) ** 2

    return 0.5 * (var_0 / var_1 + mean_gap_sq / var_1 + var_1.log() - var_0.log() - 1.)


def approx_standard_normal_cdf(x):
    """
    Fast tanh-based approximation of the standard normal cumulative
    distribution function, evaluated element-wise on tensor ``x``.

    NOTE: this repeats an identical definition that appears earlier in the
    module; being the later one, it is the binding that remains after import.
    """
    scale = np.sqrt(2.0 / np.pi)
    inner = x + 0.044715 * torch.pow(x, 3)
    return 0.5 * (1.0 + torch.tanh(scale * inner))

def log_prob_gaussian(x,mu,sigma2):
    """Element-wise log-density of ``x`` under ``Normal(mu, sqrt(sigma2))``.

    Args:
        x: points at which to evaluate the density.
        mu: mean tensor.
        sigma2: variance tensor (its square root is used as the scale).

    Returns:
        Tensor of per-element log-probabilities (no reduction is applied).
    """
    return Normal(mu, torch.sqrt(sigma2)).log_prob(x)

def log_discretized_normal(x, mu, var):  # element-wise
    """Element-wise log-probability of ``x`` under a discretized normal.

    The probability mass is the (approximate) normal CDF difference over the
    bin ``[x - 1/255, x + 1/255]``. Values below -0.999 take the whole lower
    tail and values above 0.999 the whole upper tail. Every probability is
    clamped to at least 1e-12 before the log.

    NOTE: this repeats an identical definition that appears earlier in the
    module; being the later one, it is the binding that remains after import.
    """
    half_bin = 1. / 255
    std = var ** 0.5
    centered = x - mu

    upper_cdf = approx_standard_normal_cdf((centered + half_bin) / std)
    lower_cdf = approx_standard_normal_cdf((centered - half_bin) / std)

    at_lower_edge = upper_cdf.clamp(min=1e-12).log()
    at_upper_edge = (1. - lower_cdf).clamp(min=1e-12).log()
    inside = (upper_cdf - lower_cdf).clamp(min=1e-12).log()

    return torch.where(
        x < -0.999,
        at_lower_edge,
        torch.where(x > 0.999, at_upper_edge, inside),
    )

def log_prob_mixturegaussian(x,mu1,mu2,sigma1_2,sigma2_2):
    """Element-wise log-density of a two-component Gaussian mixture.

    The mixture is ``0.3 * N(mu1, sigma1_2) + 0.7 * N(mu2, sigma2_2)``. The
    previous implementation got these weights by repeating the two
    log-densities 3 and 7 times along a new axis before a logsumexp. That
    allocated ten copies of the input. Here the weights are added in
    log-space and the two terms are combined with ``torch.logaddexp``.

    Args:
        x: points at which to evaluate the density.
        mu1, mu2: component means.
        sigma1_2, sigma2_2: component variances (tensors).

    Returns:
        Tensor of per-element mixture log-probabilities.
    """
    logp1 = Normal(mu1, sigma1_2.pow(0.5)).log_prob(x)
    logp2 = Normal(mu2, sigma2_2.pow(0.5)).log_prob(x)
    # log(0.3 * p1 + 0.7 * p2), evaluated stably in log-space.
    return torch.logaddexp(logp1 + math.log(0.3), logp2 + math.log(0.7))

def log_prob_disctre_mixturegaussian(x,mu1,mu2,sigma1_2,sigma2_2):
    """Element-wise log-probability of a two-component discretized mixture.

    Each component is scored with ``log_discretized_normal`` and the
    components are mixed with weights 0.3 and 0.7. The previous
    implementation got these weights by repeating the two log-probabilities
    3 and 7 times along a new axis before a logsumexp. That allocated ten
    copies of the input. Here the weights are added in log-space and the two
    terms are combined with ``torch.logaddexp``.

    Args:
        x: points at which to evaluate the probability.
        mu1, mu2: component means.
        sigma1_2, sigma2_2: component variances.

    Returns:
        Tensor of per-element mixture log-probabilities.
    """
    logp1 = log_discretized_normal(x, mu1, sigma1_2)
    logp2 = log_discretized_normal(x, mu2, sigma2_2)
    # log(0.3 * p1 + 0.7 * p2), evaluated stably in log-space.
    return torch.logaddexp(logp1 + math.log(0.3), logp2 + math.log(0.7))

def extract(v, t, x_shape,ratio=None):
    """
    Gather the entries of ``v`` at indices ``t`` (as float32) and reshape them
    to ``[t.shape[0], 1, 1, ...]`` with ``len(x_shape)`` dims so they
    broadcast against a tensor of shape ``x_shape``.

    ``ratio`` is accepted but not used.
    """
    picked = v.gather(0, t).float()
    singleton_dims = [1] * (len(x_shape) - 1)
    return picked.view([t.shape[0]] + singleton_dims)

class TemporaryGrad:
    """Context manager that turns autograd on for the duration of a block.

    The grad mode seen on entry is stored in ``self.prev`` and restored on
    exit, so it can be nested inside a ``torch.no_grad()`` region.
    """

    def __enter__(self):
        # Remember the caller's mode before enabling gradient tracking.
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(True)

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Restore whatever mode was active before the block.
        torch.set_grad_enabled(self.prev)

def solve_gmm(mean,cov,ske,kur,gt,timestep):
    """Fit a two-component Gaussian mixture to given moments by gradient descent.

    The mixture has fixed weights (1/3, 2/3), component means ``x0`` and
    ``x1`` and a shared variance ``gt * beta``. The objective is the mean
    squared mismatch of the first three raw moments against ``mean``,
    ``mean**2 + cov`` and ``ske``. It is minimised with Adan under a
    warm-up + cosine learning-rate schedule for a fixed number of iterations.

    Compared with the previous version, work that never influenced the
    result has been removed: a fourth-moment term that was computed but not
    added to the loss, a loss evaluation before the loop whose result was
    unused, and re-setting ``requires_grad`` on every iteration.

    Args:
        mean: target first moment; also the starting point for both means.
        cov: target variance (enters the objective as ``mean**2 + cov``).
        ske: target third raw moment.
        kur: accepted for interface compatibility; not part of the objective.
        gt: base variance that ``beta`` scales.
        timestep: accepted for interface compatibility; unused.

    Returns:
        ``(x0, x1, beta)``: the two component means and the variance scale,
        with ``beta`` clamped to ``[0.1, 1.2]``.
    """
    device = mean.device
    init_x0 = torch.unsqueeze(mean, dim=0)
    init_x1 = torch.unsqueeze(mean - 1e-3, dim=0)
    init_beta = torch.unsqueeze(torch.ones(size=mean.size()).to(device) * 0.999, dim=0)
    x = torch.cat([init_x0, init_x1, init_beta], dim=0)
    cov_g = gt
    pi = 1/3

    def loss_f(tensor):
        # Mean squared mismatch of the first three raw moments.
        x0, x1 = tensor[0, ...], tensor[1, ...]
        beta = torch.clamp(tensor[2, ...], 0.1, 1.2)
        E0 = (pi*x0 + (1-pi)*x1 - mean).pow(2)
        E1 = (pi*(x0**2+cov_g*beta)+(1-pi)*(x1**2+cov_g*beta) - (mean**2+cov)).pow(2)
        E2 = (pi*(x0**3+3*x0*cov_g*beta)+(1-pi)*(x1**3+3*x1*cov_g*beta) - ske).pow(2)
        return (E0+E1+E2).mean()

    warm_up    = 18
    iterations = 25
    lr     = 0.01
    min_lr = 0.001

    def lr_factor(step):
        # Linear warm-up, then cosine decay floored at min_lr / lr.
        if step <= warm_up:
            return step / warm_up
        cosine = 0.5 * (math.cos((step - warm_up) / (iterations - warm_up) * math.pi) + 1)
        return max(cosine, min_lr / lr)

    # Callers may be inside torch.no_grad(); gradients are needed here.
    with TemporaryGrad():
        optimizer_solve = Adan([x], lr=lr, betas=(0.9, 0.92, 0.95))
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer_solve, lr_factor)
        x.requires_grad = True
        for _ in range(iterations):
            loss = loss_f(x)
            optimizer_solve.zero_grad()
            loss.backward()
            optimizer_solve.step()
            scheduler.step()

    return x[0,...], x[1,...],torch.clamp(x[2,...], 0.1, 1.2)

def gaussian_gmm(mean,cov,ske,kur,gt):
    """Fit a single variance scale ``beta`` for a Gaussian centred at ``mean``.

    With the mean fixed to ``mean`` and the variance set to ``gt * beta``,
    the objective combines the squared mismatch of the second raw moment
    (against ``mean**2 + cov``) and of the third raw moment (against
    ``ske``). The third-moment term is rescaled by the ratio of the two mean
    errors. ``beta`` is optimised with Adan under a warm-up + cosine schedule.

    Compared with the previous version, a first-moment error term that was
    identically zero and never used has been dropped, and ``requires_grad``
    is set once instead of on every iteration.

    Args:
        mean: mean of the Gaussian (held fixed).
        cov: target variance.
        ske: target third raw moment.
        kur: accepted for interface compatibility; unused.
        gt: base variance that ``beta`` scales.

    Returns:
        The optimised ``beta`` clamped to ``[0.1, 1.2]``.
    """
    device = mean.device
    x0 = torch.unsqueeze(mean, dim=0)
    x = torch.unsqueeze(torch.ones(size=mean.size()).to(device) * 0.9, dim=0)
    cov_g = gt

    def loss_f(tensor):
        beta = torch.clamp(tensor[0, ...], 0.1, 1.2)
        E1 = (x0**2+cov_g*beta - (mean**2+cov)).pow(2)
        E2 = (x0**3+3*x0*cov_g*beta - ske).pow(2)
        # NOTE(review): the E1/E2 ratio is NaN/inf when E2.mean() is zero.
        return ((E1+E1.mean()/E2.mean()*E2)).mean(),E2.mean()

    warm_up    = 6
    iterations = 200
    lr     = 0.02
    min_lr = 0.0001

    def lr_factor(step):
        # Linear warm-up, then cosine decay floored at min_lr / lr.
        if step <= warm_up:
            return step / warm_up
        cosine = 0.5 * (math.cos((step - warm_up) / (iterations - warm_up) * math.pi) + 1)
        return max(cosine, min_lr / lr)

    # Callers may be inside torch.no_grad(); gradients are needed here.
    with TemporaryGrad():
        optimizer_solve = Adan([x], lr=lr, betas=(0.9, 0.92, 0.95))
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer_solve, lr_factor)
        # Baseline loss before any update; only used for the log line below.
        pred_0, pre_E2 = loss_f(x)
        x.requires_grad = True
        for _ in range(iterations):
            pred, E2 = loss_f(x)
            optimizer_solve.zero_grad()
            pred.backward()
            optimizer_solve.step()
            scheduler.step()
    logging.info('mean optimize {0},E2 {1}'.format(pred/pred_0,E2/pre_E2))
    return torch.clamp(x[0,...], 0.1, 1.2)

def solve_analytic(mean,cov,ske):
    """Closed-form two-component mean/variance estimate with safety fallbacks.

    From the first three raw moments (``mean``, ``mean**2 + cov``, ``ske``)
    this derives two component means and a shared variance. It then falls
    back to the single-Gaussian values (``mean`` and ``cov``) element-wise
    wherever the solution looks unreliable:

    * the derived variance is negative;
    * the derived variance is at most 10% of ``cov``;
    * either component mean deviates from ``mean`` by a relative amount of
      0.5 or more.

    The constants 1.5874 and 1.25992 are approximately 2**(2/3) and 2**(1/3).

    Returns:
        ``(mean1, mean2, variance)`` tensors.
    """
    second_moment = mean.pow(2)+cov
    cubic = 2*mean.pow(3)-3*mean*second_moment+ske
    # Negative values are zeroed so the fractional powers below stay real.
    com_p = torch.where(cubic>=0,cubic,0)

    mean1 = mean + 1.5874*(com_p).pow(1/3)
    mean2 = 0.5 * (2*mean-1.5874*(com_p).pow(1/3))
    cov_m = (-mean.pow(2)+second_moment-1.25992*(com_p).pow(2/3))

    # Fallback 1: negative variance.
    variance_ok = cov_m>=0
    cov_m2 = torch.where(variance_ok,cov_m,cov)
    mean1 = torch.where(variance_ok,mean1,mean)
    mean2 = torch.where(variance_ok,mean2,mean)

    bar1 = 0.1
    bar2 = 0.5

    # Fallback 2: variance collapsed relative to cov.
    variance_tiny = cov_m2<=bar1*cov
    cov_m3 = torch.where(variance_tiny,cov,cov_m2)
    mean1 = torch.where(variance_tiny,mean,mean1)
    mean2 = torch.where(variance_tiny,mean,mean2)

    # Fallback 3: first component mean drifted too far from `mean`.
    first_far = torch.abs((mean1-mean)/mean)>=1-bar2
    mean1f = torch.where(first_far,mean,mean1)
    mean2 = torch.where(first_far,mean,mean2)

    # Fallback 4: second component mean drifted too far from `mean`.
    second_far = torch.abs((mean2-mean)/mean)>=1-bar2
    mean2f = torch.where(second_far,mean,mean2)
    mean1f = torch.where(second_far,mean,mean1f)
    return mean1f,mean2f,cov_m3

class likelihood(nn.Module):
    def __init__(self, eps1_model,eps2_model,eps3_model,eps2_nll_model, eps3_nll_model,beta_1, beta_T, T,img_size=32,
                 sample_type='ddpm',time_shift=False,noise_schedule='linear',covmean=False):
        assert sample_type in ['ddpm', 'analyticdpm', 'gmddpm','mean_network']
        super().__init__()
        self.model      = eps1_model
        self.cov_model  = eps2_model
        self.eps3_model = eps3_model
        self.cov_model_nll = eps2_nll_model
        self.eps3_nll_model= eps3_nll_model
        self.T = T
        self.covmean = covmean
        self.total_T = 1000
        if self.total_T % self.T  ==0:
            self.ratio = int(self.total_T/self.T)
        else:
            self.ratio = int(self.total_T/self.T)+1

        self.t_list = [max(self.total_T-1-self.ratio*x,0) for x in range(T)]
        #if self.t_list[-1] != 0:
        self.t_list.append(0)
        logging.info(self.t_list)

        self.img_size  = img_size
        self.sample_type = sample_type
        self.time_shift  = time_shift

        if noise_schedule=='linear':
            betas0 = torch.tensor([0.])
            betas1 = torch.linspace(beta_1, beta_T, self.total_T).double()
            self.register_buffer(
                'betas', torch.cat([betas0,betas1],dim=0))
            alphas = 1. - self.betas
            alphas_bar = torch.cumprod(alphas, dim=0)
            # calculations for diffusion q(x_t | x_{t-1}) and others
        else:
            logging.info(noise_schedule)
            g = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
            betas = [0.]
            for i in range(self.total_T):
                t1 = i / self.total_T
                t2 = (i + 1) / self.total_T
                betas.append(min(1 - g(t2) / g(t1), 0.999))
            betas = torch.tensor(np.array(betas))
            self.register_buffer(
                'betas', betas)
            alphas= 1-betas
            alphas_bar = torch.cumprod(alphas, dim=0)
            alphas = alphas

            logging.info(alphas_bar)
            #logging.info(alphas_bar.size())
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:self.total_T+1]
        logging.info(alphas_bar_prev.size())
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'one_minus_alphas_bar', (1.- alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1.- alphas_bar))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            'sqrt_recip_alphas_bar', torch.sqrt(1. / alphas_bar))
        self.register_buffer(
            'sqrt_recipm1_alphas_bar', torch.sqrt(1. / alphas_bar - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.register_buffer(
            'posterior_var',
            self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
        
        # below: log calculation clipped because the posterior variance is 0 at
        # the beginning of the diffusion chain
        self.register_buffer(
            'posterior_log_var_clipped',
            torch.log(
                torch.cat([self.posterior_var[1:2], self.posterior_var[1:]])))
        
        self.register_buffer(
            'posterior_mean_coef1',
            torch.sqrt(alphas_bar_prev) * self.betas / (1. - alphas_bar))
        self.register_buffer(
            'posterior_mean_coef2',
            torch.sqrt(alphas) * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def q_mean_variance(self, x_0, x_t, t):
        """
        Return the mean and clipped log-variance of the diffusion posterior
        q(x_{t-1} | x_t, x_0), looked up from the precomputed buffers at
        index ``t`` and broadcast to the shape of ``x_t``.
        """
        assert x_0.shape == x_t.shape
        shape = x_t.shape
        weight_x0 = extract(self.posterior_mean_coef1, t, shape)
        weight_xt = extract(self.posterior_mean_coef2, t, shape)
        posterior_mean = weight_x0 * x_0 + weight_xt * x_t
        posterior_log_var_clipped = extract(self.posterior_log_var_clipped, t, shape)
        return posterior_mean, posterior_log_var_clipped

    # use eps to estimate the first-order moment (mean)
    def predict_xpre_from_eps(self, x_t, t, eps):
        """Predict the mean of the previous-step sample and of x_0 from ``eps``.

        The previous index ``s`` is ``t - self.ratio``, or zero when that
        would be negative. The decision looks at the first batch element
        only, so presumably every entry of ``t`` is the same (confirm with
        caller). Intermediate means are written to ``self.statistics``.

        Returns:
            ``(mean_xs, mean_x0)``: the posterior mean at index ``s``
            (clamped to [-100, 100]) and the x_0 estimate (clamped to [-1, 1]).
        """
        assert x_t.shape == eps.shape
        shape = x_t.shape

        # Target index of this reverse step; collapse to 0 near the chain start.
        if (t-self.ratio)[0]>=0:
            s = t-self.ratio
        else:
            s = t-t

        a_t = extract(self.sqrt_alphas_bar, t, shape)
        a_s = extract(self.sqrt_alphas_bar, s, shape)
        a_ts = extract(self.sqrt_alphas_bar, t, shape)/extract(self.sqrt_alphas_bar, s, shape)
        sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, s, shape))
        sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, shape))
        beta_ts = sigma_t**2-a_ts**2*sigma_s**2

        stats = self.statistics
        mean_x0 = (x_t - sigma_t * eps)/a_t
        stats['xt_mean'] = x_t.mean().item()
        stats['eps_mean'] = eps.mean().item()
        stats['unclip mean_x0_mean'] = mean_x0.mean().item()
        mean_x0 = mean_x0.clamp(-1.,1.)
        stats['clip mean_x0_mean'] = mean_x0.mean().item()

        weight_xt = a_ts*sigma_s.pow(2)/(sigma_t.pow(2))
        weight_x0 = a_s*beta_ts/(sigma_t.pow(2))
        mean_xs = weight_xt * x_t + weight_x0 * mean_x0
        mean_xs = mean_xs.clamp(-100.,100.)
        stats['clip mean_xs_max'] = mean_xs.max().item()
        return mean_xs,mean_x0

    # use eps and eps2 to estimate the second-order moment (variance)
    def predict_xpre_cov_from_eps(self, x_t, t, eps):
        """Estimate the variance of the previous-step sample from ``eps``/``eps2``.

        ``eps2`` comes from ``self.cov_model`` evaluated at ``t + 1``. The
        original code did this identically whether or not ``self.time_shift``
        was set. The x_0 variance estimate is clamped to [0, 1], pushed
        through the posterior coefficients and added to the small posterior
        variance. Diagnostics are written to ``self.statistics``.

        Returns:
            ``(model_var, eps2, cov_x0_pred)``; ``model_var`` is clamped to
            [0, 1].
        """
        eps2 = self.cov_model(x_t, t+1)
        shape = x_t.shape
        a_t = extract(self.sqrt_alphas_bar, t, shape)

        # Target index of this reverse step; collapse to 0 near the chain start.
        # Only the first batch element is inspected.
        if (t-self.ratio)[0]>=0:
            s = t-self.ratio
        else:
            s = t-t

        a_s = extract(self.sqrt_alphas_bar, s, shape)
        # \alpha_{t|s}
        a_ts = extract(self.sqrt_alphas_bar, t, shape)/extract(self.sqrt_alphas_bar, s, shape)
        sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, shape))
        sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, s, shape))
        beta_ts = sigma_t**2-a_ts**2*sigma_s**2

        stats = self.statistics
        sigma2_small = (sigma_s**2*beta_ts)/(sigma_t**2)
        cov_x0_pred = sigma_t.pow(2)/a_t.pow(2) * (eps2-eps.pow(2))
        stats['unclip cov_x0_mean'] = cov_x0_pred.mean().item()
        cov_x0_pred = cov_x0_pred.clamp(0., 1.)
        stats['clip cov_x0_mean'] = cov_x0_pred.mean().item()

        offset = a_s.pow(2)*beta_ts.pow(2)/sigma_t.pow(4) * cov_x0_pred
        stats['offset'] = offset.mean().item()
        stats['offset_max'] = offset.max().item()
        stats['sigma2_small'] = sigma2_small.mean().item()

        model_var = (sigma2_small + offset).clamp(0., 1.)
        return model_var,eps2,cov_x0_pred
    
    def predict_xpre_cov_from_eps_nll(self, x_t, t,eps):
        """Variant of the variance estimate that queries the models at ``t``.

        ``eps2`` comes from ``self.cov_model`` evaluated at ``t + 1 - 1``. The
        original code did this identically whether or not ``self.time_shift``
        was set. The x_0 variance estimate is clamped to [0, 1], pushed
        through the posterior coefficients and added to the small posterior
        variance. Diagnostics are written to ``self.statistics``.

        Returns:
            ``(model_var, eps2, cov_x0_pred)``; ``model_var`` is clamped to
            [0, 1].
        """
        eps2 = self.cov_model(x_t, t+1-1)
        # NOTE(review): this prediction is not used in the returned values;
        # the call is kept in case the forward pass has side effects.
        epsc_pred = self.cov_model_nll(x_t, t+1-1)

        shape = x_t.shape
        a_t = extract(self.sqrt_alphas_bar, t, shape)

        # Target index of this reverse step; collapse to 0 near the chain start.
        # Only the first batch element is inspected.
        if (t-self.ratio)[0]>=0:
            s = t-self.ratio
        else:
            s = t-t

        a_s = extract(self.sqrt_alphas_bar, s, shape)
        # \alpha_{t|s}
        a_ts = extract(self.sqrt_alphas_bar, t, shape)/extract(self.sqrt_alphas_bar, s, shape)
        sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, shape))
        sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, s, shape))
        beta_ts = sigma_t**2-a_ts**2*sigma_s**2

        stats = self.statistics
        sigma2_small = (sigma_s**2*beta_ts)/(sigma_t**2)
        cov_x0_pred = sigma_t.pow(2)/a_t.pow(2) * (eps2-eps.pow(2))
        stats['unclip cov_x0_mean'] = cov_x0_pred.mean().item()
        cov_x0_pred = cov_x0_pred.clamp(0., 1.)
        stats['clip cov_x0_mean'] = cov_x0_pred.mean().item()

        offset = a_s.pow(2)*beta_ts.pow(2)/sigma_t.pow(4) * cov_x0_pred
        stats['offset'] = offset.mean().item()
        stats['offset_max'] = offset.max().item()
        stats['sigma2_small'] = sigma2_small.mean().item()

        model_var = (sigma2_small + offset).clamp(0., 1.)
        return model_var,eps2,cov_x0_pred
    
    def ddpm_cov(self, x_t, t,big=True):
        """Return the fixed DDPM-style variance for the step starting at ``t``.

        Args:
            x_t: current sample; only its shape is used.
            t: index tensor; only its first element decides which branch runs.
            big: when True return ``beta_ts`` (the larger choice), otherwise
                the small posterior variance.

        When ``t - self.ratio`` is not strictly positive, the step goes
        straight to index 0, "first steps" is logged and the small variance is
        set to ``beta_ts`` as well. The small variance is also recorded in
        ``self.statistics['sigma2_small']``.
        """
        shape = x_t.shape
        sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, shape))

        if (t-self.ratio)[0]>0:
            s = t-self.ratio
            # \alpha_{t|s}
            a_ts = extract(self.sqrt_alphas_bar, t, shape)/extract(self.sqrt_alphas_bar, s, shape)
            sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, s, shape))
            beta_ts = sigma_t**2-a_ts**2*sigma_s**2
            model_var1 = (sigma_s**2*beta_ts)/(sigma_t**2)
        else:
            logging.info('first steps')
            s = t-t
            # \alpha_{t|s}
            a_ts = extract(self.sqrt_alphas_bar, t, shape)/extract(self.sqrt_alphas_bar, s, shape)
            sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, s, shape))
            beta_ts = sigma_t**2-a_ts**2*sigma_s**2
            model_var1 = beta_ts

        self.statistics['sigma2_small'] = model_var1.mean().item()
        return beta_ts if big else model_var1

    # use eps, eps2 and eps3 to estimate the third-order moment
    def predict_xpre_3moment_from_eps(self, x_t, t, eps, eps2, mean,var):
        """Estimate the third-order moment of the previous-step sample.

        ``eps3`` is predicted by ``self.eps3_model`` at index ``t - 1`` and is
        combined with ``eps`` and ``eps2`` into a third-moment estimate of
        x_0 (named ``skew_*`` in the code). That estimate, together with the
        clamped x_0 mean and second moment, is expanded through the posterior
        coefficients to give the moment at the previous index. Diagnostics
        are written to ``self.statistics``.

        Args:
            x_t: current sample.
            t: index tensor; only its first element decides which branch runs.
            eps, eps2: first- and second-order noise predictions.
            mean: not used inside this method.
            var: added to ``mean_x0**2`` to form the second moment of x_0
                (presumably the x_0 variance estimate; confirm with caller).

        Returns:
            ``(skew_xs, eps3)``.
        """
        # NOTE(review): both branches are identical, so ``time_shift`` has no
        # effect here. The index is ``t-1`` whereas the variance predictors
        # use ``t+1`` / ``t+1-1``; confirm which offset is intended.
        if self.time_shift:
            eps3 = self.eps3_model(x_t, t-1)
        else:
            eps3 = self.eps3_model(x_t, t-1)

        sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, x_t.shape))
        a_t     = extract(self.sqrt_alphas_bar, t, x_t.shape)
        # Previous index s is t - ratio, or 0 (t - t) when that would be negative.
        if (t-self.ratio)[0]>=0:
            # \alpha_{t|s}
            a_s  = extract(self.sqrt_alphas_bar, t-self.ratio, x_t.shape)
            a_ts = extract(self.sqrt_alphas_bar, t, x_t.shape)/extract(self.sqrt_alphas_bar, t-self.ratio, x_t.shape)
            sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, t-self.ratio, x_t.shape))
            beta_ts = sigma_t**2-a_ts**2*sigma_s**2
        else:
            # \alpha_{t|s}
            a_s  = extract(self.sqrt_alphas_bar, t-t, x_t.shape)
            a_ts = extract(self.sqrt_alphas_bar, t, x_t.shape)/extract(self.sqrt_alphas_bar, t-t, x_t.shape)
            sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, t-t, x_t.shape))
            beta_ts = sigma_t**2-a_ts**2*sigma_s**2

        # First moment of x_0, clamped to the data range.
        mean_x0 = (x_t - sigma_t * eps)/a_t
        mean_x0 = mean_x0.clamp(-1., 1.)
        # NOTE(review): eps_new is computed but never used below.
        eps_new = (x_t-mean_x0*a_t)/sigma_t
        # Second raw moment of x_0 built from the supplied ``var``.
        twom_x0 = var + mean_x0.pow(2)
        twom_x0 = twom_x0.clamp(0., 1.)
        # Third moment of x_0 from the cubic expansion of (x_t - sigma_t * noise) / a_t.
        skew_x02 = 1/(a_t.pow(3))*(x_t.pow(3) - sigma_t.pow(3)*eps3 - 3*x_t.pow(2)*sigma_t*eps + 3*x_t*sigma_t.pow(2)*eps2)
        # NOTE(review): skew_x03 is computed but never used below.
        skew_x03 = mean_x0.pow(3)+3*mean_x0*var
        skew_x0 = skew_x02
        
        self.statistics['unclip_x0_skew'] = skew_x0.mean().item()
        # Where the estimate exceeds |mean_x0| in magnitude, use mean_x0 instead.
        skew_x0 = torch.where(torch.abs(skew_x0)<=torch.abs(mean_x0),skew_x0,mean_x0)
        skew_x0 = skew_x0.clamp(-1., 1.)
        self.statistics['clip_x0_skew'] = skew_x0.mean().item()
        sigma2_small = (sigma_s**2*beta_ts)/(sigma_t**2)

        # Cubic expansion of (c_t * x_t + c_0 * x_0) using the x_0 moments above.
        skew_xs_part1 = (a_ts*sigma_s.pow(2)/(sigma_t.pow(2)) * x_t).pow(3)+\
            3*(a_ts*sigma_s.pow(2)/(sigma_t.pow(2)) * x_t).pow(2)*(a_s*beta_ts/sigma_t.pow(2))*mean_x0 +\
            3*(a_ts*sigma_s.pow(2)/(sigma_t.pow(2)) * x_t)*(a_s*beta_ts/sigma_t.pow(2)).pow(2)*twom_x0 +\
            (a_s*beta_ts/sigma_t.pow(2)).pow(3)*skew_x0
        # Cross term: 3 * posterior variance * posterior mean.
        skew_xs_part2 = 3*sigma2_small*(a_ts*sigma_s.pow(2)/(sigma_t.pow(2)) * x_t + a_s*beta_ts/(sigma_t.pow(2)) * mean_x0)
        skew_xs  = skew_xs_part1+skew_xs_part2
        self.statistics['clip_xs_skew'] = skew_xs.mean().item()
        return skew_xs,eps3

    def p_mean_variance(self, x_t, t):
        """Predict the mean, covariance and third moment for the reverse step.

        ``self.model`` is queried at ``t`` when ``self.time_shift`` is truthy
        and at ``t-1`` otherwise. The eps3 output of the third-moment helper
        is discarded.

        Returns:
            ``(mean, cov, skeness, cov)`` -- the covariance is returned twice;
            the fourth slot is what callers unpack as their reference variance.
        """
        step = t if self.time_shift else t-1
        eps = self.model(x_t, step)

        mean, x0_mean = self.predict_xpre_from_eps(x_t, t, eps=eps)
        cov, eps2, x0_cov = self.predict_xpre_cov_from_eps_nll(x_t, t, eps)
        skeness, _ = self.predict_xpre_3moment_from_eps(x_t, t, eps, eps2, x0_mean, x0_cov)
        return mean, cov, skeness, cov

    def _q_transition(self, x_src, idx_src, idx_dst):
        """Sample a noisier ``x_dst`` from ``x_src`` and return the coefficients used.

        Computes ``a_ts * x_src + sqrt(beta_ts) * noise`` with fresh standard
        normal noise, where the coefficients are looked up from
        ``self.sqrt_alphas_bar`` / ``self.one_minus_alphas_bar`` at the two
        index tensors.

        Returns:
            ``(x_dst, a_s, a_ts, sigma_s, sigma_t, beta_ts)``.
        """
        shape = x_src.shape
        a_s  = extract(self.sqrt_alphas_bar, idx_src, shape)
        a_ts = extract(self.sqrt_alphas_bar, idx_dst, shape)/a_s
        sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, idx_src, shape))
        sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, idx_dst, shape))
        beta_ts = sigma_t**2-a_ts**2*sigma_s**2
        x_dst = a_ts * x_src + torch.sqrt(beta_ts) * torch.randn_like(x_src)
        return x_dst, a_s, a_ts, sigma_s, sigma_t, beta_ts

    def forward(self,x0,t_count,s_count):
        """Score the reverse step ``t_count -> s_count`` on a batch ``x0``.

        All returned values are divided by ``log(2)``.

        Args:
            x0: batch of clean samples.
            t_count: index of the later (noisier) step.
            s_count: index of the earlier step; must be smaller than ``t_count``.

        Returns:
            A 3-tuple whose meaning depends on the branch:

            * ``t_count <= s_count``: the sentinel ``(1e10, 1e10, 1e10)``.
            * ``s_count == 0``: ``(ddpm, gaussian, mixture - gaussian)`` where
              each term is the batch mean of the negative discretised
              log-likelihood of ``x0``. Note the third entry is a
              *difference*, not the mixture value itself.
            * otherwise: ``(ddpm, gaussian, min(mixture, gaussian))`` where
              each term is the batch mean of ``log q - log p`` evaluated at
              the sampled ``x_s``.

        Side effects:
            Resets ``self.statistics`` and sets ``self.ratio`` (read by the
            moment-prediction helpers) to ``t_count - s_count``.
        """
        self.statistics = {}
        if t_count<=s_count:
            return 1e10,1e10,1e10

        batch = x0.shape[0]
        t = x0.new_ones([batch, ], dtype=torch.long)  * t_count

        if s_count==0:
            logging.info('from {0} to 0'.format(t_count))
            t0 = x0.new_ones([batch, ], dtype=torch.long)  * 0
            x_t = self._q_transition(x0, t0, t)[0]

            self.ratio = t_count
            mean,cov,skeness,pre_cov = self.p_mean_variance(x_t=x_t, t=t)
            mean1,mean2,beta = solve_gmm(mean,cov,skeness,None,pre_cov,0)

            # Gaussian ************************
            if self.covmean:
                # Broadcast the scalar mean to the full shape. ones_like avoids
                # the NaNs that ``mean1/mean1`` produces where mean1 is 0 or inf.
                cov = torch.ones_like(mean1) * cov.mean()
            cov_ddpm =  self.ddpm_cov(x_t=x_t, t=t)
            nll_ddpm  = -log_discretized_normal(x0,mean,cov_ddpm)
            nll_gauss = -log_discretized_normal(x0,mean,cov)

            logging.info('the likelihood of p(x_0|x_1) of ddpm is {0}'.format(nll_ddpm.mean()/math.log(2.)))
            logging.info('the likelihood of p(x_0|x_1) of adpm is {0}'.format(nll_gauss.mean()/math.log(2.)))

            # GMDDPM ************************
            # Both mixture components share one scalar variance.
            var = torch.ones_like(mean1) * (beta*pre_cov).mean()
            nll_mix = -log_prob_disctre_mixturegaussian(x0,mean1,mean2,var,var)
            if nll_mix.mean()<nll_gauss.mean():
                logging.info('Gaussian Mixture outperforms in the first steps')
            logging.info('the likelihood of p(x_0|x_1) of gdpm is {0}'.format(nll_mix.mean()/math.log(2.)))

            log_error = nll_mix-nll_gauss
            logging.info('\n')
            return nll_ddpm.mean().cpu().detach().numpy()/math.log(2.),nll_gauss.mean().cpu().detach().numpy()/math.log(2.),log_error.mean().cpu().detach().numpy()/math.log(2.)

        logging.info('from {0} to {1}'.format(t_count,s_count))
        num_mc = 1  # Monte Carlo repetitions averaged per call
        kl_ddpm_runs = []
        kl_gauss_runs = []
        kl_mix_runs = []

        for _ in range(num_mc):
            # Record the ratio
            self.ratio = t_count-s_count
            s = x0.new_ones([batch, ], dtype=torch.long)  * s_count

            # Generate x_s from x_0 (source index 0), then x_t from x_s. The
            # coefficients of the second transition are reused for the
            # ground-truth posterior below.
            x_s = self._q_transition(x0, s-s, s)[0]
            x_t,a_s,a_ts,sigma_s,sigma_t,beta_ts = self._q_transition(x_s, s, t)

            mean,cov,skeness,pre_cov = self.p_mean_variance(x_t=x_t, t=t)
            mean1,mean2,beta = solve_gmm(mean,cov,skeness,None,pre_cov,t_count)

            # Gaussian ************************
            if self.covmean:
                cov = torch.ones_like(mean1) * cov.mean()
            # DDPM ************************
            cov_ddpm =  self.ddpm_cov(x_t=x_t, t=t)
            nll_ddpm  = -log_prob_gaussian(x_s,mean,cov_ddpm)
            nll_gauss = -log_prob_gaussian(x_s,mean,cov)

            # Ground Truth ************************
            q_mean = a_ts* sigma_s.pow(2)/(sigma_t.pow(2)) * x_t + a_s*beta_ts/(sigma_t.pow(2)) * x0
            q_cov  = ((sigma_s.pow(2)*beta_ts)/sigma_t.pow(2))
            log_q  = log_prob_gaussian(x_s,q_mean,q_cov)

            # GMDDPM ************************
            var = torch.ones_like(mean1) * (pre_cov*beta).mean()
            nll_mix = -log_prob_mixturegaussian(x_s,mean1,mean2,var,var)

            kl_gauss_runs.append((nll_gauss + log_q).mean().cpu().detach().numpy())
            kl_mix_runs.append((nll_mix + log_q).mean().cpu().detach().numpy())
            kl_ddpm_runs.append((nll_ddpm + log_q).mean().cpu().detach().numpy())

        p1_ddpm = np.mean(kl_ddpm_runs)
        p1_gaussian  = np.mean(kl_gauss_runs)
        p1_mgaussian = np.mean(kl_mix_runs)

        logging.info('the likelihood for ddpm is {0} for this step'.format(p1_ddpm/math.log(2)))
        logging.info('the likelihood for gaussian is {0} for this step'.format(p1_gaussian/math.log(2)))
        logging.info('the likelihood for mgaussian is {0} for this step'.format(p1_mgaussian/math.log(2)))
        if p1_mgaussian<p1_gaussian:
            logging.info('Gaussian Mixture outperforms from {0} to {1}'.format(t_count,s_count))
        else:
            # The mixture never reports worse than the Gaussian (also covers NaN).
            p1_mgaussian = p1_gaussian
        logging.info('\n')
        return p1_ddpm/math.log(2.),p1_gaussian/math.log(2.),p1_mgaussian/math.log(2.)

       
def infiniteloop(dataloader):
    """Yield the first element of every ``(x, y)`` batch, cycling forever.

    Args:
        dataloader: re-iterable of 2-tuples ``(x, y)``; it is re-iterated
            from the start each time it is exhausted.

    Yields:
        The ``x`` element of each batch.

    Raises:
        ValueError: if a full pass over ``dataloader`` produces no batches.
            Without this check an empty loader (e.g. ``drop_last=True`` with
            a batch size larger than the dataset) would spin forever without
            ever yielding.
    """
    while True:
        produced = False
        for x, _ in dataloader:
            produced = True
            yield x
        if not produced:
            raise ValueError('dataloader yielded no batches; cannot loop over an empty loader')

def Sample_parallel(net_sampler):
    """Sweep all ``0 -> t`` steps over CIFAR-10 test batches and save a running mean.

    For every batch, ``net_sampler(x_0, t_count, s_count)`` is evaluated for
    ``s_count = 0`` and every ``t_count`` in ``[0, 1001)``. Only the third
    returned value is accumulated, as a running mean over batches, into a
    1001x1001 table (cells never visited keep the sentinel ``1e10``).

    The table is written to
    ``./sample/likelihood/improved<noise_schedule>sn.csv`` once per completed
    sweep rather than after every single step, and the output directory is
    created up front so the write cannot fail after the computation.

    Args:
        net_sampler: callable returning a 3-tuple of scalars for
            ``(x_0, t_count, s_count)``.
    """
    save_file = './sample/likelihood/improved'+str(FLAGS.noise_schedule)+'sn.csv'
    os.makedirs(os.path.dirname(save_file), exist_ok=True)

    dataset = CIFAR10(
            root='./data', train=False, download=True,
            transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=FLAGS.batch_size, shuffle=True,
        num_workers=FLAGS.num_workers, drop_last=True)
    datalooper = infiniteloop(dataloader)

    TT = 1001
    L_gdpm = np.zeros((TT, TT))+1e10

    batch_starts = trange(0, FLAGS.num_images, FLAGS.batch_size)
    for batches_seen, _ in enumerate(batch_starts, start=1):
        x_0 = next(datalooper).to(device).float()
        # Only the s = 0 row is swept; the loop is kept so the range can be widened.
        for s_count in range(0,1):
            for t_count in range(s_count,TT):
                logging.info('t-s {0}'.format(t_count-s_count))
                _, _, L_temp_gdpm = net_sampler(x_0,t_count,s_count)
                # Running mean: on the first batch the 1e10 sentinel is
                # multiplied by zero and therefore replaced outright.
                previous_sum = L_gdpm[s_count,t_count]*(batches_seen-1)
                L_gdpm[s_count,t_count] = (previous_sum + L_temp_gdpm)/batches_seen
            pd.DataFrame(L_gdpm).to_csv(save_file)

def eval():
    """Build the five pretrained UNets, wrap them in ``likelihood`` and run the sweep.

    Checkpoints and model variants are selected from ``FLAGS.time_shift`` and
    ``FLAGS.noise_schedule``. Every network is switched to eval mode, and the
    sweep (``Sample_parallel``) runs under ``torch.no_grad()``.

    NOTE: this function shadows the builtin ``eval``; the name is kept because
    ``main`` calls it.
    """
    cosine = FLAGS.noise_schedule == 'cosine'

    # Keyword arguments shared by every UNet variant built below.
    base = dict(
        in_channels=FLAGS.in_channel,
        model_channels=FLAGS.ch,
        out_channels=FLAGS.out_channel,
        num_res_blocks=FLAGS.num_res_blocks,
        attention_resolutions=FLAGS.attn,
        dropout=FLAGS.dropout,
        channel_mult=FLAGS.ch_mult,
        conv_resample=True,
        dims=FLAGS.dims,
        num_classes=None,
        use_checkpoint=False,
        num_heads=FLAGS.num_heads,
        num_heads_upsample=-1,
        use_scale_shift_norm=FLAGS.use_scale_shift_norm,
    )

    def pretrained(model_cls, mode):
        # Variants with an extra output head take two more keyword arguments.
        return model_cls(head_out_channels=FLAGS.head_out_channels, mode=mode, **base)

    # ---- eps (first-order) network -------------------------------------
    if FLAGS.time_shift:
        eps1_model = pretrained(UNetModel4Pretrained2, 'simple')
        if cosine:
            ckpt1 = torch.load('/home/aiops/allanguo/cifar/logs/cifar10_cosine1000_ema_eps_eps2_pretrained_460000.ckpt.pth')
        else:
            ckpt1 = torch.load('/home/aiops/allanguo/MixtureGaussianDiffusion/models/cifar10_ema_eps_eps2_pretrained_340000.ckpt.pth')
    else:
        if not cosine:
            logging.info(FLAGS.noise_schedule)
        eps1_model = UNetModel(**base)
        if cosine:
            ckpt1 = torch.load('./logs/iDDPM_CIFAR10_cos_EPS1/models/ckpt_1_800000.pt')['ema_model']
        else:
            ckpt1 = torch.load('./logs/iDDPM_CIFAR10_EPS1/models/ckpt_1_800000.pt')['ema_model']
    eps1_model.load_state_dict(ckpt1)
    eps1_model.eval()

    # ---- eps2 networks ---------------------------------------------------
    eps2_model = pretrained(UNetModel4Pretrained, 'simple')
    eps2_nll_model = pretrained(UNetModel4Pretrained, 'simple')
    if cosine:
        ckpt2_nll = torch.load('/home/aiops/allanguo/cifar/logs/cifar10_cosine1000_ema_eps_epsc_pretrained_150000.ckpt.pth')
        ckpt2 = torch.load('/home/aiops/allanguo/cifar/logs/cifar10_cosine1000_ema_eps_eps2_pretrained_460000.ckpt.pth')
    else:
        ckpt2 = torch.load('/home/aiops/allanguo/MixtureGaussianDiffusion/models/cifar10_ema_eps_eps2_pretrained_340000.ckpt.pth')
        ckpt2_nll = torch.load('/home/aiops/allanguo/cifar/logs/cifar10_ema_eps_epsc_pretrained_190000.ckpt.pth')

    eps2_model.load_state_dict(ckpt2)
    eps2_nll_model.load_state_dict(ckpt2_nll)
    eps2_model.eval()
    eps2_nll_model.eval()

    # ---- eps3 network ----------------------------------------------------
    if cosine:
        eps3_model = pretrained(UNetModel4Pretrained, 'complex2')
        ckpt3_path = './logs/iDDPM_CIFAR10_cos_EPS3_c2/models/ckpt_3_1500000.pt'
    else:
        eps3_model = pretrained(UNetModel4Pretrained, 'complex')
        ckpt3_path = './logs/iDDPM_CIFAR10_EPS3_2/models/ckpt_3_800000.pt'
    eps3_model.load_state_dict(torch.load(ckpt3_path)['ema_model'])
    eps3_model.eval()

    # ---- eps3 NLL network ------------------------------------------------
    eps3_nll_model = pretrained(UNetModel4Pretrained, 'complex')
    if cosine:
        eps3_nll_path = './logs/iDDPM_CIFAR10_cos_EPS3_nll/models/ckpt_3_1400000.pt'
    else:
        eps3_nll_path = './logs/iDDPM_CIFAR10_nll_EPS3/models/ckpt_3_1200000.pt'
    eps3_nll_model.load_state_dict(torch.load(eps3_nll_path)['ema_model'])
    eps3_nll_model.eval()

    net_sampler = likelihood(
        eps1_model, eps2_model, eps3_model, eps2_nll_model, eps3_nll_model,
        FLAGS.beta_1, FLAGS.beta_T, FLAGS.sample_steps, FLAGS.img_size,
        FLAGS.sample_type, FLAGS.time_shift, FLAGS.noise_schedule, FLAGS.covmean,
    ).to(device)
    if FLAGS.parallel:
        net_sampler = torch.nn.DataParallel(net_sampler)
    with torch.no_grad():
        Sample_parallel(net_sampler)

def main(argv):
    """Entry point handed to ``app.run``: mute FutureWarning, then run ``eval``.

    ``argv`` is accepted because the entry point is called with it; it is not
    used.
    """
    warnings.simplefilter('ignore', FutureWarning)
    eval()

# Start the absl app only when this file is executed as a script, so that
# importing the module does not parse flags and launch the whole evaluation.
if __name__ == '__main__':
    app.run(main)



