using Test, Base.Iterators
# compute v̂ using nested binary search (once for λ, once for v̂)


include("phi.jl")


# Bisection search for the zero of an increasing function `f` on [lo, hi].
#
# The bracket is halved until the midpoint coincides with the endpoint it is
# about to replace (floating-point precision exhausted), at which point that
# midpoint is returned. After 200 halvings a warning is logged (once) and the
# midpoint of the remaining bracket is returned.
function binary_search(f, lo, hi)
    # sanity check (with a little slack) that the bracket is ordered correctly
    @assert f(lo) ≤ 1e-7 && -1e-7 ≤ f(hi) "f($lo)=$(f(lo))  f($hi)=$(f(hi))"

    budget = 200
    while budget > 0
        budget -= 1
        centre = (lo + hi) / 2
        below = f(centre) < 0
        # the endpoint that `centre` would replace
        edge = below ? lo : hi
        # precision exhausted: the bracket cannot shrink any further
        edge == centre && return centre
        if below
            lo = centre
        else
            hi = centre
        end
    end
    @warn "iteration budget exhausted, $lo $hi" maxlog=1
    return (lo + hi) / 2
end

# Bisection-based solver for the λ satisfying Σᵢ qᵢ e^{ηᵢ (xᵢ - λ)} = 1.
#
# `qs` must sum to (approximately) one. The bracket handed to `binary_search`
# is widened by 1e-10 on both sides.
function bs_λ(qs, ηs, xs)
    @assert sum(qs) ≈ 1

    # convexity-based lower end: the η-weighted mean of xs under qs
    lower = sum(prod, zip(qs, ηs, xs)) / sum(prod, zip(qs, ηs))

    # crude upper end: the largest x
    upper = maximum(xs)

    # -Σᵢ qᵢ (e^{ηᵢ (xᵢ - λ)} - 1); the search looks for the zero of this
    function excess(λ)
        total = sum(zip(qs, ηs, xs)) do (q, η, x)
            q*expm1(η*(x-λ))
        end
        return -total
    end

    return binary_search(excess, lower - 1e-10, upper + 1e-10)
end


# Find λ such that \sum_i us_i e^{η(Xs-λ)} = 1
#
# We solve the numerically stable (and still convex) equation
# \log(\sum_i us_i e^{η(Xs-λ)}) = 0
# using Newton's method. This is globally convergent

# Newton solver for the λ with log(Σᵢ usᵢ e^{ηsᵢ (xsᵢ - λ)}) = 0.
#
# The iterate starts at the value that makes the largest term equal to one and
# is only ever moved upwards; the first step that fails to increase it ends
# the iteration and the previous iterate is returned. After 30 steps a warning
# is logged and the latest iterate is returned instead.
function λ(us, ηs, xs)
    # log usᵢ + ηsᵢ xsᵢ, independent of the iterate
    logterms = log.(us) .+ ηs.*xs

    cur = maximum(logterms ./ ηs)

    for n in countfrom(1)
        expo = logterms .- ηs.*cur
        top = maximum(expo)
        # shifted by the maximum so that the exponentials cannot overflow
        scaled = exp.(expo .- top)
        total = sum(scaled)
        val = top + log(total) # constraint value
        slope = - ηs'scaled / total # derivative w.r.t. the iterate (negative)
        nxt = cur - val/slope # Newton step
        # no further progress representable
        nxt ≤ cur && return cur
        if n == 30
            # pretty arbitrary limit for giving warning
            @warn "Many Newton iterations for λ. Wazaap? $val"
            return nxt
        end
        cur = nxt
    end
end

# Check both λ solvers (bisection-based `bs_λ` and Newton-based `λ`) on random
# instances: for K = 2..15 components and 20 trials each, the returned value
# must satisfy Σᵢ usᵢ e^{ηsᵢ (xsᵢ - λ)} ≈ 1.
@testset "$fλ" for fλ in (bs_λ, λ)
    for K in 2:15
        for it in 1:20
            # random weights summing to one: gaps between sorted uniform draws
            us = diff([0, sort(rand(K-1))..., 1])
            # normal draws scaled by 100
            xs = 100 .* randn(K)
            # uniform draws from [0, 1)
            ηs = rand(K)
            ntλ = fλ(us, ηs, xs)
            @test sum(us.*exp.(ηs.*(xs .- ntλ))) ≈ 1
        end
    end
end




# Value of  inf_λ  λ + Σᵢ qᵢ (e^{ηᵢ (xᵢ - λ)} - 1)/ηᵢ,
# evaluated at the λ returned by the Newton solver `λ(qs, ηs, xs)`.
function Φ(qs, ηs, xs)
    λopt = λ(qs, ηs, xs)
    penalty = sum(zip(qs, ηs, xs)) do (q, η, x)
        q*expm1(η*(x-λopt))/η
    end
    return λopt + penalty
end

# Exponentially tilted weights qᵢ e^{ηᵢ (xᵢ - λ)}, with λ = λ(qs, ηs, xs).
function q̃(qs, ηs, xs)
    λopt = λ(qs, ηs, xs)
    tilt = exp.(ηs.*(xs .- λopt))
    return qs .* tilt
end

# Upper bound on v̂ based on range information.
#
# With b = maxᵢ ηᵢσᵢ this evaluates
#     (-log(1 - b) - b)/b²  =  1/2 + b/3 + b²/4 + …
# Note: this grows without bound as b → 1 (and `log1p` throws for b > 1).
#
# For small |b| the closed form is 0/0 at b = 0 and otherwise loses most of
# its digits to cancellation, so the truncated series is used there instead.
# The first omitted term is b⁶/8, below 1.3e-19 on that branch.
function v̂ubd(ηs, σs)
    b = maximum(prod, zip(ηs, σs))
    if abs(b) < 1e-3
        c = float(b)
        return oftype(c, evalpoly(c, (1/2, 1/3, 1/4, 1/5, 1/6, 1/7)))
    end
    return (-log1p(-b)-b)/b^2
end


# Bisection-based solver for v̂: searches for the zero of
#     v̂ ↦ -Φ(qs, ηs, qs'ℓs .- ℓs .- ηs .* σs.^2 .* v̂)
# on the bracket [0, v̂ubd(ηs, σs)], widened by 1e-7 on both sides.
#
# Requires ηs ≥ 0, sum(qs) ≈ 1 and -σs ≤ ℓs ≤ σs elementwise.
# With `debug=true` the objective is additionally plotted over the bracket with
# the located root marked (this needs `plot`, `vline!` and `gui` in scope).
function update(qs, ℓs, ηs, σs; debug=false)
    @assert all(ηs .≥ 0)
    @assert sum(qs) ≈ 1
    @assert all(-σs .≤ ℓs .≤ σs)

    # loss of the qs-mixture
    mix = qs'ℓs

    # shifted regrets as a function of the candidate v̂
    shifted = v -> mix .- ℓs .- ηs .* σs.^2 .* v
    # the function whose zero is sought
    objective = v -> -Φ(qs, ηs, shifted(v))

    vlo = 0 - 1e-7
    vhi = v̂ubd(ηs, σs) + 1e-7

    root = binary_search(objective, vlo, vhi)

    if debug
        grid = range(vlo, vhi, length=2000)
        plot(grid, objective)
        vline!([root])
        gui()
    end

    return root
end


# Check that `v̂ubd` really is an upper bound: evaluated at v̂ = v̂ubd(ηs, σs),
# Φ must be (numerically) non-positive. Random instances with K = 2..5
# components, 100 trials each.
@testset "̂vubd" begin
    γ = 1 - 1e-7; # largest tolerable "b"
    for K in 2:5
        for i in 1:100
            σs = rand(K);
            # chosen so that every product ηᵢσᵢ stays below γ
            ηs = γ * σs .* rand(K);
            # losses drawn uniformly within [-σᵢ, σᵢ]
            ℓs = (2 .* rand(K) .- 1) .* σs;
            # random weights summing to one: gaps between sorted uniform draws
            qs = diff([0, sort(rand(K-1))..., 1]);
            # mixture loss minus individual losses
            rs = qs'ℓs .- ℓs;

            @test Φ(qs, ηs, rs .- ηs .* σs.^2 .* v̂ubd(ηs, σs)) ≤ 1e-7
        end
    end
end


# TODO: maintaining weights incrementally runs the risk of them underflowing to zero.
# Idea: instead maintain the sum of the losses and variances (neither of which overflows)
# Joint two-dimensional Newton iteration for the pair (v, λ).
#
# With xᵢ = λ - ℓᵢ - ηᵢσᵢ²v, the two residuals driven to zero are
#   f00 = -v*Z + Σᵢ qᵢ ηᵢ xᵢ² tlrϕ(ηᵢxᵢ)
#   f01 = Σᵢ qᵢ (e^{ηᵢxᵢ} - 1)
# The commented-out line in the loop gives what appears to be an equivalent
# form of f00; `tlrϕ` is not defined in this file (presumably it comes from
# phi.jl — confirm).
#
# Returns the tuple (v, qs'ℓs - λ) as soon as |f00|/Z + |f01| < 1e-15.
# The assertion at the end of the loop body fails once iteration 30 completes
# without that test having returned.
function _newton(qs, ℓs, ηs, σs)
    Z = sum(qs .* ηs .* σs.^2) # natural scale of the "v" constraint

    # starting iterate is solution to second-order expansion of problem around ηs = 0
    λ = sum(qs .* ηs .* ℓs) / sum(qs .* ηs)
    v = sum(qs .* ηs .* (λ .- ℓs).^2) / Z

    for it in countfrom()
        # function evaluation and derivatives. Argument order is (v, λ)
        # wmq is "w minus q": w = qs .* exp.(…) computed via expm1
        wmq = qs .* expm1.(ηs .* (λ .- ℓs .- ηs .* σs.^2 .* v))
        w = wmq .+ qs
        #f00 = qs'ℓs - λ + sum(wmq ./ ηs)
        f00 = -v*Z + sum(zip(qs, ηs, ℓs, σs)) do (q, η, ℓ, σ)
            let x = λ - ℓ - η*σ^2*v
                q*η*x^2*tlrϕ(η*x)
            end
        end
        f01 = sum(wmq)

        # converged: both residuals are (jointly) below tolerance
        if abs(f00)/Z + abs(f01) < 1e-15
            return v, qs'ℓs - λ
        end

        # entries of the Jacobian used in the Newton step below
        f10 = - sum(w .* ηs .* σs.^2)
        f02 = + sum(w .* ηs)
        f11 = - sum(w .* ηs.^2 .* σs.^2)

        # Newton step: v, λ = [v; λ] .- [f10 f01; f11 f02] \ [f00; f01]
        # (the 2×2 solve is written out by hand via the determinant)
        det = f10*f02 - f01*f11

        v -= (f00*f02 - f01^2)/det
        λ -= (f01*f10 - f00*f11)/det

        # println("$it  v $v  λ $λ  f00/Z $(f00/Z)  f01 $f01  det $det")

        @assert it < 30 "Newton not converging"
    end
end

# `newton` returns only the first component (v) of `_newton`'s result
const newton = first ∘ _newton


# solve for λ, v̂ in
# 0 = -1 + \sum_i w_i  e^{η_i (r_i - λ) - η_i^2 σ_i^2 v̂}
# 0 =  λ + \sum_i w_i (e^{η_i (r_i - λ) - η_i^2 σ_i^2 v̂}-1)/η_i
# where r_i = w'ℓ - ℓ_i
#
# we instead solve the numerically stabler jointly convex system
# 0 = \ln(\sum_i w_i e^{η_i (r_i - λ) - η_i^2 σ_i^2 v̂})
# 0 = \ln(\sum_i w_i e^{η_i (r_i - λ) - η_i^2 σ_i^2 v̂}/η_i) - \ln(\sum_i w_i/η_i - λ)
#
#
# TODO: We're still in trouble once some ws have underflowed to zero.
# we can never get their contributions back in this way.
# Proposal: work with explicit equality to the previous Φ_{t-1}.
#
# TODO: More generally, why should Newton globally converge here?

# Joint Newton iteration on (λ, v̂) for the log-domain system
#   0 = ln(Σᵢ wᵢ e^{xᵢ})
#   0 = ln(Σᵢ wᵢ e^{xᵢ}/ηᵢ) - ln(Σᵢ wᵢ/ηᵢ - λ)
# with xᵢ = ηᵢ (rᵢ - λ) - ηᵢ² σᵢ² v̂ and rᵢ = ws'ℓs - ℓᵢ.
#
# `ws` must sum to (approximately) one. Exactly 20 Newton steps are taken from
# the start (λ, v̂) = (Σᵢ wᵢ/ηᵢ - 1, 0); there is no convergence test yet (see
# the TODO in the loop). Returns the final v̂. An assertion fails if an
# iterate with negative v̂ is evaluated.
#
# The per-iteration trace is emitted with `@debug` rather than printed to
# stdout, so it is silent unless debug logging is enabled.
function newton2(ws, ℓs, ηs, σs)
    @assert sum(ws) ≈ 1 # weights for prev. Φ

    r = ws'ℓs .- ℓs
    lnw  = log.(ws)
    lnwη = log.(ws./ηs)
    Q = sum(ws./ηs)

    # value and Jacobian; both log-sum-exps are shifted by their maximum
    # exponent (am, bm) before exponentiating
    fJ(λv̂) = let (λ,v̂) = λv̂,
        xs = ηs .* (r .- λ) - ηs.^2 .* σs.^2 .* v̂,
        am = maximum(lnw  .+ xs), as = exp.(lnw  .+ xs .- am), A = sum(as),
        bm = maximum(lnwη .+ xs), bs = exp.(lnwη .+ xs .- bm), B = sum(bs)
        @assert v̂ ≥ 0 "negative v̂ $v̂"
        [am + log(A),
         bm + log(B) - log(Q - λ)],
        [-sum(ηs.*as)/A            -sum(ηs.^2 .* σs.^2 .* as)/A;
         -sum(ηs.*bs)/B+1/(Q - λ)  -sum(ηs.^2 .* σs.^2 .* bs)/B
         ]
    end

    λv̂ = [Q-1, 0.]
    for it in 1:20
        # TODO: when to stop? Small det(J) is dangerous, but why would it hold at the solution? Small Newton.decr? Small violation norm (which one)?
        f, J = fJ(λv̂)
        λv̂ .-=  J\f # Newton step
        @debug "$it  $λv̂  $f $(sqrt(f'f))"
    end

    return λv̂[2]
end







# Find Δv such that λ(Δv) + \sum_i ws_i (e^{η_i (X_i(Δv) - λ(Δv))} - 1)/η_i = 0,
# where X_i(Δv) = rs_i - η_i σ_i² Δv and λ(Δv) is the λ solving the constraint for X(Δv).
#
# Note that the left hand side is convex in Δv. We solve the equation
# using Newton's method. This is globally convergent.

# One-dimensional Newton iteration for Δv, with λ re-solved at every step by
# the Newton solver `λ(ws, ηs, X)`.
#
# `ws` must be non-negative and sum to (approximately) one.
# Returns the tuple (Δv, λ) taken *before* the last Newton step, as soon as
# the residual Φ drops below 1e-15 or the step no longer increases Δv. After
# 30 iterations a warning is logged and that same pair is returned.
# `tlrϕ` is not defined in this file (presumably it comes from phi.jl — confirm).
function _newton3(ws, ℓs, ηs, σs)
    @assert sum(ws) ≈ 1 # weights for prev. Φ
    @assert all(ws .≥ 0)

    # mixture loss minus individual losses
    rs = ws'ℓs .- ℓs

    # initialise a little too small
    Δv = 0. # TODO: this may not be too small if we're working in the optimistic case, then Δv < 0 is the norm.
    for it in countfrom(1)
        # NB: within this `let`, `λ` and `Φ` are local values that shadow the
        # functions of the same name; the right-hand side `λ(ws, ηs, X)` still
        # refers to the function.
        let X = rs .- ηs .* σs.^2 .* Δv,
            λ = λ(ws, ηs, X),
            #Φ = λ + sum(zip(ws, ηs, X)) do (w, η, x)
            #    w*expm1(η*(x-λ))/η
            #end
            Φ = sum(zip(ws, ηs, rs, σs)) do (w, η, r, σ)
                let q = r - λ - η*σ^2*Δv
                    w*η*(q^2*tlrϕ(η*q) - σ^2*Δv)
                end
            end

            # derivative of the residual w.r.t. Δv, used as the Newton slope
            dΦdv = - sum(zip(ws, ηs, σs, X)) do (w, η, σ, x)
                w*η*σ^2*exp(η*(x-λ))
            end

            oldΔv = Δv
            Δv -= Φ/dΦdv # Newton step
            if Φ < 1e-15 || Δv ≤ oldΔv
                # no further progress representable
                return oldΔv, λ
            end
            if it == 30 # pretty arbitrary limit for giving warning
                @warn "Many Newton iterations for Δv. Wazaap? $Φ"
                return oldΔv, λ
            end
        end
    end
end

# `newton3` returns only the first component (Δv) of `_newton3`'s result
const newton3 = first ∘ _newton3


# End-to-end check of the three v̂ solvers (`update`, `newton`, `newton3`) on
# random instances: for K = 2..15 components and 20 trials each, plugging the
# returned v back into Φ must give (approximately) zero.
@testset "$fv" for fv in (update, newton, newton3)
    for K in 2:15
        for it in 1:20
            # random weights summing to one: gaps between sorted uniform draws
            us = diff([0, sort(rand(K-1))..., 1])
            xs = 100 .* randn(K)
            ηs = rand(K)
            # tilted weights produced from (us, ηs, xs)
            qs = q̃(us, ηs, xs)
            # evenly spaced from 0 to 1, so the first entry is 0
            σs = range(0,1,K)
            # normal draws clamped elementwise into [-σᵢ, σᵢ]
            ℓs = clamp.(randn(K), -σs, σs)

            v = fv(qs, ℓs, ηs, σs)
            @test Φ(qs, ηs, qs'ℓs .- ℓs .- ηs .* σs.^2 .* v)  ≈ 0  atol=1e-7
        end
    end
end
