
using DrWatson
@quickactivate "BBVIConvergence"

using DelimitedFiles
using Plots, StatsPlots
using Random123
using DataFrames
using UnPack

include(srcdir("BBVIConvergence.jl"))
include("utils.jl")

"""
    kl_divergence(p::MvNormal, q::MvNormal)

Return KL(p ‖ q) for two multivariate Gaussians via the closed form

    KL = ½ [ log|Σq| − log|Σp| − d + tr(Σq⁻¹ Σp) + (μq − μp)ᵀ Σq⁻¹ (μq − μp) ],

where `d` is the dimension.
"""
function kl_divergence(p::MvNormal, q::MvNormal)
    dim = length(p.μ)

    # logabsdet returns (log|det|, sign); only the magnitude is needed here.
    logdet_ratio = first(logabsdet(q.Σ)) - first(logabsdet(p.Σ))
    trace_term   = tr(q.Σ \ p.Σ)
    mahalanobis  = PDMats.invquad(q.Σ, q.μ - p.μ)

    (logdet_ratio - dim + trace_term + mahalanobis)/2
end

"""
    estimate_expected_squared_norm(rng, logdensityprob, λ, M, φ, ψ⁻¹, ϕ,
                                   unflatten, param_type, estimator_type, ad_type;
                                   n_samples = 128)

Monte Carlo estimate of 𝔼‖g‖² at `λ`: the average of `sum(abs2, g)` over `n_samples`
calls to `grad_elbo!`, where `g` is the gradient `grad_elbo!` writes into a
`DiffResults.GradientResult` buffer. Each call presumably draws fresh samples from
`rng`. All positional arguments are forwarded to `grad_elbo!` unchanged.

Throws `ArgumentError` if `n_samples` is not positive.
"""
function estimate_expected_squared_norm(
      rng, logdensityprob, λ::AbstractVector,
      M::Int, φ, ψ⁻¹, ϕ,
      unflatten, param_type, estimator_type, ad_type;
      n_samples::Integer = 128,
    )
    # mapreduce over an empty range would fail with an opaque "empty collection" error.
    n_samples > 0 ||
        throw(ArgumentError("n_samples must be positive, got $(n_samples)"))

    # One buffer, reused across samples; the norm is taken before it is overwritten.
    grad_buf = DiffResults.GradientResult(λ)

    mapreduce(+, 1:n_samples) do _
        grad_elbo!(rng, logdensityprob, λ, M, φ, ψ⁻¹, ϕ,
                   unflatten, param_type, estimator_type, ad_type,
                   nothing, grad_buf)
        sum(abs2, DiffResults.gradient(grad_buf)) / n_samples
    end
end

"""
    run_vi(m₀, C₀, m_opt, C_opt, L, γ, ϵ, T, estimator;
           rng,
           param_type                 = :squareroot,
           estimate_gradient_variance = nothing,
           early_terminate            = false)

Run projected BBVI (`bbvi`) on the Gaussian target `N(m_opt, C_opt*C_opt')`, starting
from `(m₀, C₀)`, with `T` passed as the iteration budget, optimizer `Descent(γ)`, and
`M = 1`.

The callback passed to `bbvi` records, for the current iterate,
- `kl`:  `kl_divergence(q, target)`;
- `Δλ²`: `‖m - m_opt‖² + ‖C - C_opt‖²_F`, the squared distance to the optimum.

If `estimate_gradient_variance` is a function, then every `max(T ÷ 100, 1)` steps
(starting at `t = 1`) it also records a Monte Carlo estimate `𝔼g²` of the expected
squared gradient norm. That estimate is merged with the named tuple returned by
`estimate_gradient_variance(Δλ²)`.

Iterates are projected by `svd_projection`, which clamps the scale from below at 1/√L.
With `early_terminate = true`, `terminate` reports `true` once `Δλ² ≤ ϵ`.

Returns a `NamedTuple` with fields `key` and `t_key` for each recorded stat `key`. Both
come from `filter_stats(key, stats)`; presumably they are the values and the iterations
at which they were recorded (confirm against `filter_stats`).
"""
function run_vi(m₀, C₀, m_opt, C_opt, L, γ, ϵ, T, estimator;
                rng,
                param_type                 = :squareroot,
                estimate_gradient_variance = nothing,
                early_terminate            = false)
    # Target distribution, whose exact variational optimum is (m_opt, C_opt).
    μ = m_opt
    Σ = C_opt*C_opt' |> Hermitian
    π = MvNormal(μ, Σ)
    Turing.@model function quadratic()
        z ~ π
    end
    model = quadratic()
    b     = Bijectors.bijector(model)
    b⁻¹   = inverse(b)
    prob  = DynamicPPL.LogDensityFunction(model)
    M     = 1

    ϕ          = identity
    optimizer  = Descent(γ)
    ad_type    = ReverseDiffAD

    _, unflatten = get_flatten_utils(Val(param_type), prob)

    # Gradient-variance estimates are expensive, so take about 100 of them per run.
    # Clamping at 1 avoids `mod(t, 0)` (DivideError) for T < 100. It also keeps the
    # estimate firing for 100 ≤ T < 200, where `mod(t, 1) == 1` would never hold.
    grad_var_every = max(div(T, 100), 1)

    # Split the flat parameter vector into location m and scale C
    # (meanfield scale returned as a Diagonal matrix).
    function location_scale(λ)
        if param_type == :meanfield
            m, s, _ = unflatten(λ)
            m, Diagonal(s)
        elseif param_type == :squareroot
            m, _, C = unflatten(λ)
            m, C
        else
            throw(ArgumentError("unsupported param_type $(param_type)"))
        end
    end

    # Squared Euclidean/Frobenius distance of λ to the optimum (m_opt, C_opt).
    function squared_distance_to_opt(λ)
        m, C = location_scale(λ)
        sum(abs2, m - m_opt) + sum(abs2, C - C_opt)
    end

    # Uses `step_stats` (not `stats`) so the closure does not capture and box the
    # enclosing `stats` variable assigned from `bbvi` below.
    function callback!(t, _, λ, q, _, _)
        q   = contruct_q(param_type, λ, ϕ, unflatten)
        Δλ² = squared_distance_to_opt(λ)

        step_stats = (kl = kl_divergence(q, π), Δλ² = Δλ²)
        if !isnothing(estimate_gradient_variance) && mod(t - 1, grad_var_every) == 0
            𝔼g² = estimate_expected_squared_norm(rng, prob, λ, M,
                                                 Normal(), b⁻¹, ϕ,
                                                 unflatten, param_type,
                                                 estimator, ad_type)
            theory = estimate_gradient_variance(Δλ²)
            merge(step_stats, theory, (𝔼g² = 𝔼g²,))
        else
            step_stats
        end
    end

    # Project λ so the scale is bounded below by 1/√L.
    # NOTE(review): the square-root branch rebuilds C as U*D*U' from the left singular
    # vectors only. That matches clamping eigenvalues only when C is symmetric PSD. For
    # a symmetric C with a negative eigenvalue, it flips that eigenvalue's sign.
    # Confirm this is the intended projection.
    function svd_projection(λ, flatten, unflatten)
        if param_type == :meanfield
            m, s, _  = unflatten(λ)
            s        = max.(s, 1 ./ sqrt(L))
            flatten(m, s, nothing)
        elseif param_type == :squareroot
            m, _, C  = unflatten(λ)
            U, D, _  = svd(C)
            D        = max.(D, 1 ./ sqrt(L))
            C_proj   = U*Diagonal(D)*U' |> Hermitian
            flatten(m, nothing, C_proj)
        end
    end

    # Short-circuits, so the distance is only computed when early stopping is enabled.
    terminate(_, λ, _, _) = early_terminate && squared_distance_to_opt(λ) ≤ ϵ

    _, _, stats = bbvi(prob, M, T, m₀, C₀;
                       rng           = rng,
                       ψ⁻¹           = b⁻¹,
                       ϕ             = ϕ,
                       optimizer     = optimizer,
                       show_progress = true,
                       callback!     = callback!,
                       param_type    = param_type,
                       estimator_type = estimator,
                       terminate     = terminate,
                       projection    = svd_projection,
                       ad_type       = ad_type)

    # Stat keys differ across iterations (gradient-variance keys appear only on some
    # steps), so collect their union in first-seen order without splatting all T
    # entries into a single vcat call.
    stat_keys = unique(Iterators.flatten(keys(stat) for stat ∈ stats))
    mapreduce(merge, stat_keys) do key
        ts, ys = filter_stats(key, stats)
        (Symbol("t_$(key)") => ts,
         key                => ys,) |> NamedTuple
    end
end

"""
    compute_problem_constants(μ_post, Σ_post, m, C, m₀, C₀, ϵ, estimator)

Compute the constants used by the theoretical gradient-variance and complexity bounds,
treating the target as `N(μ_post, Σ_post)`:

- `L = 1/eigmin(Σ_post)` and `μ = 1/eigmax(Σ_post)`, the extreme eigenvalues of `Σ_post⁻¹`;
- `(α, β)` for the adaptive `δ` (derived from `ϵ`), and `(α_δ1, β_δ1)` for `δ = 1`;
- `(α_domke, β_domke)` for the Domke-style bound (`Inf` for `StickingTheLanding`).

For `StickingTheLanding`, `δ` is driven by `sum(abs2, Σ_post\\C - inv(C))`. For other
estimators it is driven by `sum(abs2, μ_post - m) + sum(abs2, C)`. `m₀` and `C₀` are
accepted but unused.

Returns a `NamedTuple` with keys
`(α, β, α_δ1, β_δ1, α_domke, β_domke, δ, μ, L)`.
"""
function compute_problem_constants(μ_post, Σ_post, m, C, m₀, C₀, ϵ, estimator)
    d = length(μ_post)
    L = 1/eigmin(Σ_post)
    μ = 1/eigmax(Σ_post)
    k = 3

    if !(estimator isa StickingTheLanding)
        modevar = sum(abs2, μ_post - m) + sum(abs2, C)
        δ       = 2/ϵ*(d + k)/(d + k + 4)*modevar
        return (α       = L^2*(d + k + 4)*(1 + δ),
                β       = L^2*(d + k)*(1 + 1/δ)*modevar,
                α_δ1    = L^2*(d + k + 4)*(1 + 1),
                β_δ1    = L^2*(d + k)*(1 + 1/1)*modevar,
                α_domke = L^2*4*(d + 3),
                β_domke = L^2*4*(d + 3)*modevar + d*L,
                δ       = δ,
                μ       = μ,
                L       = L)
    end

    # STL: δ is floored at machine epsilon so 1/δ stays finite when `fisher` ≈ 0.
    fisher = sum(abs2, Σ_post\C - inv(C))
    δ      = max(fisher/ϵ, eps(typeof(ϵ)))
    return (α       = 4*L^2*(d + k)*(1 + δ/2),
            β       = (d + k)*(1 + 2*1/δ)*fisher,
            α_δ1    = 4*L^2*(d + k)*(1 + 1/2),
            β_δ1    = (d + k)*(1 + 2)*fisher,
            α_domke = Inf,
            β_domke = Inf,
            δ       = δ,
            μ       = μ,
            L       = L)
end

"""
    complexity_guarantee(problem_constants, ϵ, m, C, m₀, C₀, estimator)

Return `(T, γ)`, the predicted iteration count and the step size from the bound. Both
are computed from `μ, α, β` in `problem_constants` and the initial squared distance
`Δλ₀² = ‖m - m₀‖² + ‖C - C₀‖²_F`. `estimator` is accepted but unused.

NOTE(review): `T` is negative whenever `Δλ₀² < ϵ/2`, because the `log` factor is negative.
"""
function complexity_guarantee(problem_constants, ϵ, m, C, m₀, C₀, estimator)
    @unpack μ, α, β = problem_constants

    initial_dist² = sum(abs2, m - m₀) + sum(abs2, C - C₀)
    stepsize      = min(ϵ*μ/4/β, μ/2/α, 2/μ)
    iterations    = max(4*β/μ^2/ϵ, 2*α/μ^2, 1/2)*log(2*initial_dist²/ϵ)

    return iterations, stepsize
end

"""
    gradient_variance(problem_constants, Δλ², estimator)

Evaluate the linear theoretical bounds `a*Δλ² + b` on 𝔼‖g‖² at squared distance `Δλ²`.

For `StickingTheLanding`, returns `(𝔼g²_theory,)`. For any other estimator, also returns
the Domke-style bound (`𝔼g²_theory_domke`) and the `δ = 1` bound (`𝔼g²_theory_δ1`).
"""
function gradient_variance(problem_constants, Δλ², estimator)
    @unpack α, β, α_δ1, β_δ1, α_domke, β_domke = problem_constants

    bound(slope, offset) = slope*Δλ² + offset

    estimator isa StickingTheLanding && return (𝔼g²_theory = bound(α, β),)

    return (𝔼g²_theory       = bound(α,       β),
            𝔼g²_theory_domke = bound(α_domke, β_domke),
            𝔼g²_theory_δ1    = bound(α_δ1,    β_δ1))
end

"""
    test(; estimator  = StickingTheLanding{false}(),
           param_type = :squareroot,
           d = 30, T = 10000, ϵ = 0.0001, key = 1)

Interactive sanity check. It builds a random Gaussian target `N(μ, 0.01 I)` in dimension
`d`, computes the theoretical constants and the predicted complexity `T_pred`, and logs
them. It then runs `run_vi` for `T` iterations from `(0, I)`. The plot shows `Δλ²` on a
log scale, with a vertical line at `T_pred` and a horizontal line at `ϵ`.

The keywords replace the previously hard-coded settings and default to the same values.
"""
function test(; estimator  = StickingTheLanding{false}(),
                param_type = :squareroot,
                d          = 30,
                T          = 10000,
                ϵ          = 0.0001,
                key        = 1)
    seed = (0x97dcb950eaebcfba, 0x741d36b68bef6415)
    rng  = Random123.Philox4x(UInt64, seed, 8)

    Random123.set_counter!(rng, key)
    Random.seed!(key)

    m₀ = zeros(d)
    # Float64 identity; `Diagonal(I, d)` would produce a `Diagonal{Bool}`.
    C₀ = Diagonal(ones(d))

    μ = randn(rng, d)
    Σ = Diagonal(0.01*I, d) |> PDMats.PDMat

    # Optimal variational parameters: mean μ and the symmetric square root of Σ.
    m = μ
    C = if param_type == :meanfield
        sqrt.(diag(Σ)) |> Diagonal
    else
        U, D, _ = svd(Σ)
        (U*Diagonal(sqrt.(D))*U') |> Hermitian
    end

    prob_consts = compute_problem_constants(μ, Σ, m, C, m₀, C₀, ϵ, estimator)
    T_pred, γ   = complexity_guarantee(prob_consts, ϵ, m, C, m₀, C₀, estimator)

    @info("Problem Stats",
          ϵ, d, γ, T_pred,
          μ = prob_consts.μ,
          L = prob_consts.L,
          α = prob_consts.α,
          β = prob_consts.β,
          δ = prob_consts.δ,
          )

    estimate_gradient_variance(Δλ²) = begin
        gradient_variance(prob_consts, Δλ², estimator)
    end

    stats = run_vi(m₀, C₀, m, C, prob_consts.L, γ, ϵ, T, estimator;
                   rng,
                   estimate_gradient_variance,
                   early_terminate = false,
                   param_type      = param_type)
    Plots.plot!(stats[:Δλ²], yscale=:log10)
    Plots.vline!([T_pred])
    Plots.hline!([ϵ])
end


"""
    complexity_run(est, ϵ; n_rep = 5, T_max = 50000, d = 10,
                   Ns = round.(Int, 10 .^range(2, 6; length=10)))

For every setting `N ∈ Ns`, call `run_setting` for `n_rep` seeds (keys `1:n_rep`) and
record the first iteration at which `Δλ² ≤ ϵ`. A run that does not get there within
`T_max` iterations counts as `T_max`. The mean and the 10%/90% quantiles of these
iteration counts are plotted against the predicted `T_pred` of the first seed, on
log–log axes.

Returns `(Ns, T_pred, T_m, T_d, T_u)`: the settings, the predicted counts, and the mean,
10%-quantile and 90%-quantile of the observed counts.
"""
function complexity_run(est, ϵ;
                        n_rep = 5,
                        T_max = 50000,
                        d     = 10,
                        Ns    = round.(Int, 10 .^range(2, 6; length=10)))
    # Distinct names at every nesting level, so no closure reassigns (and boxes) an
    # enclosing local.
    summaries = map(Ns) do N
        reps = map(1:n_rep) do key
            run_stats = run_setting(key, est, ϵ, d, N, T_max)

            # Searching the trace directly is equivalent to padding it with Inf up to
            # T_max, without the allocation and without failing on longer traces.
            T_hit    = findfirst(≤(ϵ), run_stats[:Δλ²])
            T_actual = (isnothing(T_hit) || T_hit > T_max) ? T_max : T_hit

            (T      = T_actual,
             T_pred = run_stats.T_pred,)
        end
        Ts = [rep.T for rep ∈ reps]
        (T      = mean(Ts),
         T_d    = quantile(Ts, 0.1),
         T_u    = quantile(Ts, 0.9),
         T_pred = first(reps).T_pred)
    end
    name = (est isa StickingTheLanding) ? "STL" : "CFE"

    T_m    = [s.T      for s ∈ summaries]
    T_d    = [s.T_d    for s ∈ summaries]
    T_u    = [s.T_u    for s ∈ summaries]
    T_pred = [s.T_pred for s ∈ summaries]

    Plots.plot!(Ns, T_m, ribbon=(T_m - T_d, T_u - T_m),
                xscale=:log10, yscale=:log10, label="$(name) Actual") |> display
    Plots.plot!(Ns, T_pred,
                xscale=:log10, yscale=:log10, label="$(name) Theory") |> display
    Ns, T_pred, T_m, T_d, T_u
end

"""
    complexity(; save = false)

Run `complexity_run` for the closed-form-entropy (CFE) estimator and then for the
sticking-the-landding (STL) estimator, with `ϵ = 1e-4`, overlaying both on a fresh plot.

With `save = true`, each result is also written to
`papers/bbvi_stl/data/complexity_{cfe,stl}_small_epsilon.csv` with header
`N,Tpred,Tmean,Tlo,Thi`. This replaces the previously commented-out writers; the STL
header there was missing its trailing newline.

Returns the STL result tuple `(Ns, T_pred, T_m, T_d, T_u)`.
"""
function complexity(; save::Bool = false)
    name = "small_epsilon"
    ϵ    = 0.0001

    Plots.plot() |> display

    # Write one estimator's results as CSV (no-op unless `save` is set).
    function save_csv(tag, results)
        save || return nothing
        Ns, T_pred, T_m, T_d, T_u = results
        open(projectdir("papers/bbvi_stl/data/complexity_$(tag)_$(name).csv"), "w") do io
            write(io, "N,Tpred,Tmean,Tlo,Thi\n")
            writedlm(io, hcat(Ns, T_pred, T_m, T_d, T_u), ',')
        end
        nothing
    end

    cfe = complexity_run(ClosedFormEntropy{false}(), ϵ)
    save_csv("cfe", cfe)

    stl = complexity_run(StickingTheLanding{false}(), ϵ)
    save_csv("stl", stl)

    return stl
end

"""
    gradient_variance()

Gradient-variance experiment (d = 30, N = 10^3, ϵ = 1e-4, T_max = 10000, key = 1).

For the closed-form-entropy (CFE) estimator and then the sticking-the-landing (STL)
estimator, this calls `run_setting` with `estimate_grad_var=true`. It plots the estimated
𝔼‖g‖² against the theoretical bound(s) and writes the series as CSV to
`papers/bbvi_stl/data/gradient_variance_{cfe,stl}_small_epsilon.csv`.

NOTE(review): the CFE part reads `:𝔼g²_domke` / `:𝔼g²_alt` (and their `t_` variants).
However, `gradient_variance(problem_constants, Δλ², estimator)` in this file emits
`:𝔼g²_theory_domke` / `:𝔼g²_theory_δ1`. This only works if `run_setting` (not visible
here) produces the former names. Confirm, otherwise these lookups throw a KeyError.
"""
function gradient_variance()
    key   = 1
    d     = 30
    N     = 10^3
    ϵ     = 0.0001
    T_max = 10000
    name  = "small_epsilon"

    # ---- Closed-form entropy ------------------------------------------------
    cfe = run_setting(key, ClosedFormEntropy{false}(), ϵ, d, N, T_max;
                      estimate_grad_var=true)

    t_cfe        = cfe[:t_𝔼g²]
    y_cfe        = cfe[:𝔼g²]
    y_cfe_theory = cfe[:𝔼g²_theory]
    display(Plots.plot( t_cfe, y_cfe,           yscale=:log10, label="𝔼g²"))
    display(Plots.plot!(t_cfe, y_cfe_theory,    yscale=:log10, label="𝔼g² theory"))
    display(Plots.plot!(t_cfe, cfe[:𝔼g²_domke], yscale=:log10, label="𝔼g² Domke"))
    display(Plots.plot!(t_cfe, cfe[:𝔼g²_alt],   yscale=:log10, label="𝔼g² δ=1"))

    open(projectdir("papers/bbvi_stl/data/gradient_variance_cfe_$(name).csv"), "w") do io
        write(io,
              "tgradvar,ygradvar," *
              "tadaptive,yadaptive," *
              "tdeltaone,ydeltaone," *
              "tdomke,ydomke\n")

        deltaone = (cfe[:t_𝔼g²_alt],    cfe[:𝔼g²_alt])
        domke    = (cfe[:t_𝔼g²_domke],  cfe[:𝔼g²_domke])
        gradvar  = (cfe[:t_𝔼g²],        cfe[:𝔼g²])
        adaptive = (cfe[:t_𝔼g²_theory], cfe[:𝔼g²_theory])

        # Column order: estimate, adaptive-δ bound, δ = 1 bound, Domke bound.
        writedlm(io, hcat(gradvar..., adaptive..., deltaone..., domke...), ',')
    end

    # ---- Sticking the landing -----------------------------------------------
    stl = run_setting(key, StickingTheLanding{false}(), ϵ, d, N, T_max;
                      estimate_grad_var=true)

    t_stl        = stl[:t_𝔼g²]
    y_stl        = stl[:𝔼g²]
    y_stl_theory = stl[:𝔼g²_theory]
    display(Plots.plot( t_stl, y_stl,        yscale=:log10, label="𝔼g²"))
    display(Plots.plot!(t_stl, y_stl_theory, yscale=:log10, label="𝔼g² theory"))

    open(projectdir("papers/bbvi_stl/data/gradient_variance_stl_$(name).csv"), "w") do io
        write(io,
              "tgradvar,ygradvar,tbound,ybound\n")
        writedlm(io, hcat(stl[:t_𝔼g²],        stl[:𝔼g²],
                          stl[:t_𝔼g²_theory], stl[:𝔼g²_theory]), ',')
    end
end

# function process_data(df, optimizer, param_type)
#     df = @chain df begin
#         @subset(:optimizer  .== optimizer,
#                 :param_type .== param_type)
#         @select(:logstepsize, :epsilon_T)
#     end

#     @chain groupby(df, :logstepsize) begin
#         @combine(:kl_mean   = mean(:kl),
#                  :kl_median = median(:kl),
# 		 :kl_min    = minimum(:kl),
# 		 :kl_max    = maximum(:kl),
# 		 :kl_90     = quantile(:kl, 0.9),
# 		 :kl_10     = quantile(:kl, 0.1),
# 		 )
#     end
# end
