include("model.jl")
include("../../inference/MixFlow/MixFlow.jl")
include("../common/plotting.jl")
include("../common/result.jl")
using LogExpFunctions, ProgressMeter
using JLD

# Number of leapfrog steps for the Hamiltonian flow (presumably per flow map — confirm in MixFlow).
n_lfrg = 200

# Build the Hamiltonian flow. `d`, `logp`, `∇logp` and `logq` are not defined in
# this script; presumably they come from the included model.jl.
# The five trailing arguments are MixFlow's normal-distribution helpers
# (log-density, its gradient, CDF, inverse CDF, density), passed in that order.
# NOTE(review): the roles of the two `randn` arguments are not visible here — see HamFlow.
o = let normal_helpers = (
        MixFlow.lpdf_normal,
        MixFlow.∇lpdf_normal,
        MixFlow.cdf_normal,
        MixFlow.invcdf_normal,
        MixFlow.pdf_normal,
    )
    MixFlow.HamFlow(d, n_lfrg, logp, ∇logp, randn, logq, randn, normal_helpers...)
end

# Load the saved variational result; presumably the mean-field VI fit with
# location `μ` and scale `D` — confirm against the script that writes this file.
# NOTE(review): unlike `include`, this path is resolved against the current working
# directory, not this script's directory, so run the script from its own folder.
MF = JLD.load("result/mfvi.jld")
μ, D = MF["μ"], MF["D"]

# Leapfrog step size (one entry per dimension), shared by the Float64 and BigFloat
# parameter sets below so the two configurations cannot silently drift apart.
lfrg_stepsize = 0.02

a = MixFlow.HF_params(fill(lfrg_stepsize, d), μ, D)
ϵ = a.leapfrog_stepsize

# High-precision copy of the same parameters. Converting the Float64 values with
# `ft.(...)` is exact, so the BigFloat step size, μ and D match the Float64 ones
# bit-for-bit (the step size is NOT the decimal 0.02 re-rounded at 2048 bits).
# `setprecision` changes the global default BigFloat precision for the rest of the session.
setprecision(BigFloat, 2048)
ft = BigFloat
a_big = MixFlow.HF_params(ft.(fill(lfrg_stepsize, d)), ft.(μ), ft.(D))
