using Distributions, ForwardDiff, LinearAlgebra, Random
using Base.Threads: @threads
using Base.Threads
using JLD, Tullio
using DataFrames, DelimitedFiles

# Load the numeric table. `readdlm(...; header=true)` returns (data, header); keep
# only the data. Every column except the last is a feature; the last is the response.
xs = first(readdlm("data/dat.csv", ',', Float64; header=true))
N = size(xs, 1)   # number of observations (rows)
# Parameter dimension: one coefficient per feature, one intercept, and one
# log-variance slot stored last in the parameter vector (one more than the CSV column count).
d = size(xs, 2) + 1
# Design matrix: a leading column of ones (intercept) followed by the feature columns.
fs = [ones(N) xs[:, 1:(d - 2)]]
# Response vector: the last CSV column.
rs = xs[:, end]
# Constant 0.02 for every parameter coordinate (presumably an initial step size — confirm).
ϵ0 = fill(0.02, d)

"""
    log_prior(z)

Return the log-density of a standard multivariate normal (the prior) at `z`.

The dimension is taken from `length(z)` rather than from a global, so the function is
self-contained and type-stable. For the model's parameter vector the length equals the
parameter dimension, so the value is unchanged.
"""
function log_prior(z)
    return -0.5 * dot(z, z) - 0.5 * length(z) * log(2 * pi)
end

"""
    logp_lik(β, logσ)

Gaussian log-likelihood of the responses `rs` under the linear model `rs ≈ fs * β`.

The formula treats `logσ` as the log of the noise variance: residuals are scaled by
`exp(-logσ)`, and the normalizer contributes `-N/2 * logσ`. Reads the globals `rs`,
`fs`, and `N`.
"""
function logp_lik(β, logσ)
    resid = rs - fs * β
    precision = exp(-logσ)
    sse = sum(abs2, resid)
    return -0.5 * precision * sse - 0.5 * N * log(2π) - 0.5 * N * logσ
end

"""
    ∇logp(z)

Return the gradient of `logp` with respect to the parameter vector `z`.

`z[1:d-1]` are the regression coefficients β, and `z[d]` is the log noise variance
(as used by `logp_lik`). The standard-normal prior contributes `-z`. The likelihood
contributes `exp(-logσ) * fs' * (rs - fs*β)` for β and `-N/2 + exp(-logσ) * SSE / 2`
for `logσ`. Reads the globals `d`, `fs`, `rs`, and `N`.
"""
function ∇logp(z)
    β = @view(z[1:(d - 1)])
    logσ = z[d]
    diffs = rs .- fs * β
    # fs' * diffs computes Σ_i diffs[i] * fs[i, j] for every j as one BLAS mat-vec product.
    gβ = -β .+ exp(-logσ) .* (fs' * diffs)
    # Prior term (-logσ) plus the derivative of the likelihood w.r.t. the log-variance.
    gs = -logσ - N / 2 + 0.5 * exp(-logσ) * sum(abs2, diffs)
    return vcat(gβ, gs)
end

"""
    logp(z)

Log joint density (prior plus likelihood) of the parameter vector `z`.

`z[1:d-1]` are passed to `logp_lik` as the regression coefficients and `z[d]` as the
log-variance parameter. The prior is evaluated on the whole of `z`. Reads the global `d`.
"""
function logp(z)
    coeffs = view(z, 1:(d - 1))
    return log_prior(z) + logp_lik(coeffs, z[d])
end

"""
    logq(x, μ, D)

Log-density at `x` of a diagonal Gaussian with mean `μ` and per-coordinate standard
deviation `D`.

The log-normalizer uses `abs.(D)`. The quadratic term divides by `D .+ 1e-8`, a small
jitter that guards against division by zero. The dimension comes from `length(x)` rather
than a global, so the function works for any size.
"""
function logq(x, μ, D)
    return -0.5 * length(x) * log(2π) - sum(log, abs.(D)) -
           0.5 * sum(abs2, (x .- μ) ./ (D .+ 1e-8))
end
"""
    ∇logq(x, μ, D)

Return `(μ - x) / (D + 1e-8)` elementwise.

NOTE(review): the gradient of `logq` with respect to `x` is `(μ - x) ./ D.^2`. This
function returns that gradient multiplied by `D` (up to the 1e-8 jitter), which is the
gradient with respect to ε when `x = μ + D .* ε`. Confirm that callers expect this scaling.
"""
function ∇logq(x, μ, D)
    shifted_scale = D .+ 1e-8
    return (μ .- x) ./ shifted_scale
end

# Ensure the output directories exist (created in the current working directory).
for outdir in ("figure", "result")
    isdir(outdir) || mkdir(outdir)
end