using LinearAlgebra, Statistics, StatsBase
using .MixFlow

"""
    inversion_err(o, a, Ns; nsample=100, res_dir="result/", res_name="stability.jld")

Run `MixFlow.error_checking_fwd` and `MixFlow.error_checking_bwd` for every step count in
`Ns` and save the collected errors to `joinpath(res_dir, res_name)` via `JLD.save`.

The saved file contains:
- `"fwd_err"`, `"bwd_err"`: matrices of size `(length(Ns) + 1, nsample)` with element type
  `eltype(a.μ)`. Row 1 is left at zero. Row `i + 1` holds the `nsample` values returned
  for `Ns[i]`.
- `"Ns"`: `vcat([0], Ns)`, so row `k` of the error matrices corresponds to `"Ns"[k]`.

# Arguments
- `o::MixFlow.HamFlow`: flow object passed through to the error-checking routines.
- `a::MixFlow.HF_params`: flow parameters. `a.μ` determines the storage element type.
- `Ns`: collection of step counts to evaluate.

# Keywords
- `nsample::Int=100`: number of samples per step count (columns of the error matrices).
- `res_dir="result/"`: output directory. It is created if it does not exist.
- `res_name="stability.jld"`: output file name inside `res_dir`.

# Returns
Whatever `JLD.save` returns.
"""
function inversion_err(
    o::MixFlow.HamFlow,
    a::MixFlow.HF_params,
    Ns;
    nsample::Int=100,
    res_dir="result/",
    res_name="stability.jld",
)
    # Create the output directory before the (potentially long) error computations, so a
    # missing path cannot make the final save fail after all the work is done.
    mkpath(res_dir)

    ft = eltype(a.μ)
    nrows = length(Ns) + 1  # extra leading row pairs with the 0 prepended to "Ns" below
    EE_fwd = zeros(ft, nrows, nsample)
    EE_bwd = zeros(ft, nrows, nsample)

    for (i, N) in enumerate(Ns)
        EE_fwd[i + 1, :] .= MixFlow.error_checking_fwd(o, a, N; nsample=nsample)
        @info "$(N) fwd done"
        EE_bwd[i + 1, :] = MixFlow.error_checking_bwd(o, a, N; nsample=nsample)
        @info "$(N) bwd done"
    end

    # NOTE(review): `JLD` is not imported in this file's visible header; this assumes the
    # including script/module has loaded it — confirm.
    return JLD.save(
        joinpath(res_dir, res_name),
        "fwd_err",
        EE_fwd,
        "bwd_err",
        EE_bwd,
        "Ns",
        vcat([0], Ns),
    )
end

# aux function for generating ribbon plot
function get_percentiles(dat; p1=25, p2=75)
    dat = Matrix(dat')
    n = size(dat, 2)

    plow = zeros(n)
    phigh = zeros(n)

    for i in 1:n
        dat_remove_nan = (dat[:, i])[iszero.(isnan.(dat[:, i]))]
        median_remove_nan = median(dat_remove_nan)
        plow[i] = median_remove_nan - percentile(vec(dat_remove_nan), p1)
        phigh[i] = percentile(vec(dat_remove_nan), p2) - median_remove_nan
    end

    return plow, phigh
end