using Plots, Random
using JLD2

include("Muscada.jl")
include("Hedge.jl")
include("bimat.jl")

# i.i.d. Rademacher signs: every entry is exactly +1.0 or -1.0.
# The arguments S... are forwarded unchanged to `rand`, so the shape of the
# result (scalar or array) is whatever `rand(S...)` produces.
# A draw below 0.5 maps to -1.0 and anything else to +1.0. Thresholding, as
# opposed to `sign(u - 0.5)`, never yields 0.0 when a draw equals 0.5 exactly.
rad(S...) = ifelse.(rand(S...) .< .5, -1.0, 1.0)


# Random sign matrix with bounded magnitudes: entry (i, j) is a random sign
# times min(σrow[i], σcol[j]), so no entry exceeds the bound given for its row
# or for its column.
function randmatrix(σrow, σcol)
    signs = rad(length(σrow), length(σcol))
    caps = min.(σrow, σcol')
    return caps .* signs
end

# Choose the payoff matrix A of the game. The branch is selected by editing the
# literal conditions below; as written, the `elseif true` branch is the one that runs.
if false
    # random signs, entry (i, j) of magnitude min(2^i, 2^j)
    A = randmatrix(2 .^ (1:5), 2 .^ (1:5))
elseif true
    Random.seed!(1236) # fixed seed so that the sampled signs are reproducible
    #fac = 2 .^ (1:10) # 2:10
    fac = 100*ones(10) # magnitude bounds used for the bottom-right block
    R1, R2 = 10, length(fac) # number of rows in the top / bottom blocks
    C1, C2 = 10, length(fac) # number of columns in the left / right blocks


    # 2x2 block matrix: random signs (top-left), constant -1 (top-right),
    # constant +1 (bottom-left), random signs scaled by min.(fac, fac') (bottom-right).
    # The two rad calls draw from the seeded RNG, so their order fixes A.
    A = [rad(R1, C1)  -ones(R1, C2)
         ones(R2, C1) min.(fac, fac') .* rad(R2, C2)]

    # A=A' # hard game

else
    # Nice example, found by random
    A = [  2.0  -2.0  -2.0   -2.0    2.0
          -2.0  -4.0   4.0   -4.0   -4.0
           2.0  -4.0  -8.0   -8.0   -8.0
           2.0   4.0  -8.0  -16.0  -16.0
           2.0  -4.0   8.0   16.0  -32.0]
end


# solve the matrix game perfectly
# NOTE(review): bimat comes from bimat.jl (not visible here); its three return
# values are taken to be the game value and the two players' strategies, in
# that order — confirm against bimat.jl.
val, pstar, qstar = bimat(A)
# report the computed solution on stdout
println("Saddle point value $val\npstar: $pstar\nqstar: $qstar")

# Objective for the iterative solvers: for the strategy pair (p, q), the largest
# entry of A'p minus the smallest entry of A*q.
# Reads the global payoff matrix A.
function gap(p, q)
    colmax = maximum(A' * p)
    rowmin = minimum(A * q)
    return colmax - rowmin
end

# Normalized vector: v divided by the sum of its entries.
function nrmz(v)
    total = sum(v)
    return v / total
end
# Uniform prior on K things: a length-K vector whose entries all equal 1/K.
unif(K) = fill(1 / K, K)


# Build one learner.
#   σs         per-action ranges
#   πs         prior over the actions
#   tune1      Muscada only: true selects the prior reweighted by minimum(σs)./σs
#              together with Hℓ², false selects the plain prior together with H₂
#   optimistic true doubles the ranges
#   regh       true builds a Hedge learner instead of a Muscada learner
function newAlg(σs, πs, tune1, optimistic, regh)
    σs = optimistic ? 2σs : σs # optimism doubles the range

    if !regh
        prior, H = tune1 ? (nrmz(πs .* minimum(σs) ./ σs), Hℓ²) : (πs, H₂)
        # NB since all learning rates are equal initially (as η₀ = H(0)/2σmax with H(0) = 1), projection in D_η amounts to renormalization
        return Muscada(σs, prior, H)
    end

    # The Hedge learner is only built for a uniform prior and a common range.
    K = length(σs)
    σ = maximum(σs)
    @assert all(πs .== 1/K)
    @assert all(σs .== σ) "$σs"
    #Hedge_η(K, sqrt(2*log(K)/T)/σ)
    return Hedge(K, σ)
end


# Let two learners play the matrix game A against each other for T rounds.
#
#   A             payoff matrix; the row learner P incurs loss A*q, the column
#                 learner Q incurs loss -A'p (both clamped to the ranges below)
#   T             number of rounds
#   tune1         forwarded to newAlg
#   optimistic    if true, each learner receives its previous loss vector as a
#                 guess (all zeros in the first round)
#   rangeadaptive if true, use one range per row / column (largest absolute entry
#                 of that row / column); otherwise the largest absolute entry of
#                 A is used for every action
#   regh          forwarded to newAlg
#
# Returns (P, Q, pbar, qbar, itgap, avgap): the two learners, the running
# averages of their plays, and the length-T traces of the gap of the current
# iterates (itgap) and of the averaged iterates (avgap).
#
# Both gaps are computed from the argument A, not from any global matrix, so the
# traces are correct for whichever game is passed in.
function run(A, T, tune1, optimistic, rangeadaptive, regh)
    if rangeadaptive
        # actual range
        σrow = maximum.(abs, eachrow(A))
        σcol = maximum.(abs, eachcol(A))
    else
        # maximum range
        σ = maximum(abs, A)
        σrow = fill(σ, size(A,1))
        σcol = fill(σ, size(A,2))
    end

    # Create pair of learners
    P = newAlg(σrow, unif(length(σrow)), tune1, optimistic, regh)
    Q = newAlg(σcol, unif(length(σcol)), tune1, optimistic, regh)

    if !regh
        println("\nGoing with $(P.σs), $(Q.σs)")
    end

    # running averages of the iterates
    pbar = zeros(length(σrow))
    qbar = zeros(length(σcol))

    pm = zeros(length(σrow)) # guess (previous loss for optimism)
    qm = zeros(length(σcol)) #

    # several objectives to measure here
    itgap = zeros(T)
    avgap = zeros(T)

    for t in 1:T
        if optimistic
            pt = act(P, pm)
            qt = act(Q, qm)
        else
            pt = act(P)
            qt = act(Q)
        end

        # incremental mean over rounds 1..t
        pbar .+= (pt.-pbar)./t
        qbar .+= (qt.-qbar)./t

        # payoff vectors of the current iterates, shared by the gap and the losses
        Aq  = A*qt
        Atp = A'pt

        itgap[t] = maximum(Atp) - minimum(Aq)
        avgap[t] = maximum(A'pbar) - minimum(A*qbar)

        ℓp = clamp.( Aq,  -σrow, σrow)
        ℓq = clamp.(-Atp, -σcol, σcol)

        if optimistic
            incur!(P, ℓp, pm)
            incur!(Q, ℓq, qm)
            pm = ℓp # update future guess to current loss vector
            qm = ℓq
        else
            incur!(P, ℓp)
            incur!(Q, ℓq)
        end

        #println("t: $t  vₜ: $(P.v) and $(Q.v)\nηp: $(get_ηs(P))\nηq: $(get_ηs(Q))\nwp: $pt\nwq: $qt\nℓp: $(A*qt)\nℓq: $(-A'pt)")

    end

    if !regh
        println("v $(P.v), $(Q.v)")
        println("Rp $(maximum(P.Rs)) among $(P.Rs)")
        println("Rq $(maximum(Q.Rs)) among $(Q.Rs)")
    end

    return P, Q, pbar, qbar, itgap, avgap
end


# number of rounds (earlier settings: 1_000_000, 10_000_000)
T = 25_000_000
# 10000 integer time points, spaced evenly on a log scale between 1 and T
Tstep = map(range(log(1), log(T), length=10000)) do logt
    floor(Int64, exp(logt))
end






# largest absolute entry of each row / each column of A
σrow = [maximum(abs, r) for r in eachrow(A)]
σcol = [maximum(abs, c) for c in eachcol(A)]


#plot!(itgap, label="last iterate gap")
#plot!(avgap, label="it.avg. gap")
#plot!(Oitgap, label="O. last iterate gap")
#plot!(Oavgap, label="O. it.avg. gap")


# get the results

# One entry per executed configuration:
# (P, Q, pbar, qbar, itgap, avgap, tune1, optimistic, rangeadaptive, regh)
res = Any[]

# Sweep all flag combinations (regh outermost, rangeadaptive innermost) and
# keep only the ones of interest.
for regh in false:true, tune1 in false:true, optimistic in false:true, rangeadaptive in false:true
    # only the tuned, optimistic variants are run
    (tune1 && optimistic) || continue
    # the Hedge learner is not combined with per-action ranges
    (regh && rangeadaptive) && continue

    outcome = run(A, T, tune1, optimistic, rangeadaptive, regh)

    push!(res, (outcome..., tune1, optimistic, rangeadaptive, regh))
end


# store everything needed for later analysis
jldsave("games.jld2"; res, T, Tstep, σrow, σcol, A, val, pstar, qstar)
