using LinearAlgebra
include("vhat.jl")


# *** Tunings ***
# NB: these include what in the paper is called η₀ already.

# Learning-rate tuning H₁ for a single expert.
#   v     cumulative variance
#   u     prior weight of the expert (enters through log(1/u))
#   σ     range of this expert
#   σmax  largest range among all experts
# For finite, nonzero γ the value at v = 0 is 1/(2σmax).
function H₁(v, u, σ, σmax)
    γ = 8 * σmax^2 / σ^2 * log(1 / u)
    shrink = (1 + v / (2 * γ)) / (1 + v / γ)^(3 / 2)
    return shrink / (2 * σmax)
end

# Learning-rate tuning H₂ for a single expert; same arguments as H₁:
# cumulative variance v, prior weight u, expert range σ, largest range σmax.
function H₂(v, u, σ, σmax)
    # Degenerate inputs: use the limit (= 1) of the ratio computed below,
    # which would otherwise be indeterminate in floating point.
    if σ == 0 || v == 0
        return 1 / (2 * σmax)
    end

    α = 32 * σmax^2 / σ^2
    γ = α * log(1 / u)
    x = v / α           # variance relative to α
    y = v / (2 * γ)     # variance relative to 2γ
    z = log1p(x)

    num = α * z + 1 / 2 * (2 * v + v^2 / (2 * γ)) / (1 + y)^2
    den = 2 * sqrt(α^2 * ((1 + x) * z - x) + v^2 / 2 / (1 + y))
    return num / den / (2 * σmax)
end

# Tuning that perhaps does not get luckiness, but still gets the (ℓ-m)²
# bound. It scales by the expert's own range σ; σmax is accepted so the
# signature matches the other tunings, but it is not used.
function Hℓ²(v, u, σ, σmax)
    γ = 8 * log(1 / u)
    shrink = (1 + v / (2 * γ)) / (1 + v / γ)^(3 / 2)
    return shrink / (2 * σ)
end


# *** The Algorithm ***

# This is the optimistic version, which reduces to non-optimistic when
# using zero guesses ("m") in act() and incur!().

# State of the algorithm. Construct with Muscada(σs, us, H); the regrets,
# the corrections and the cumulative variance all start at zero.
mutable struct Muscada
    σs  # NB: range of ℓ-m (loss minus guess), one entry per expert
    us  # prior (and possibly some factors relating to range)
    Rs  # regret compared to each expert
    μs  # corrections
    v   # cumulative variance
    H   # learning rate tuning function, called as H(v, u, σ, σmax)

    function Muscada(σs, us, H)
        shape = size(σs)
        # Rs and μs are separate zero arrays shaped like σs.
        return new(σs, us, zeros(shape), zeros(shape), 0.0, H)
    end
end

# Internal helper: evaluate the tuning function h.H per expert to obtain
# the learning rates. The current cumulative variance and the largest
# range are shared by all experts.
function get_ηs(h::Muscada)
    σmax = maximum(h.σs)
    return h.H.(h.v, h.us, h.σs, σmax)
end

# Internal helper: given a guess m (possibly zero), return the tuple
# (learning rates, weights). The weights come from q̃ applied to the
# regrets shifted by the guess and by the corrections.
function get_ηw(h::Muscada, m)
    ηs = get_ηs(h)
    shifted = h.Rs .- m .- h.μs
    return ηs, q̃(h.us, ηs, shifted)
end

# Weights to play in the upcoming round, given the guess m (possibly
# zero) for the upcoming losses. Does not modify h.
function act(h::Muscada, m = 0)
    _, ws = get_ηw(h, m)
    return ws
end

# Process one round: update the state of h from the discrepancy between
# the incoming losses ℓs and the guess m (possibly zero). The discrepancy
# must lie within the ranges h.σs. Evaluates to the new cumulative
# variance.
#
# TODO: it feels redundant to input the guess again here. Can this be
# made to be purely a function of the discrepancy ℓs.-m? Or could it
# be made to be such if act() also updated state based on the guess m
# over there?
function incur!(h::Muscada, ℓs, m = 0)
    gap = ℓs .- m
    @assert all(abs.(gap) .≤ h.σs) "loss-guess $(ℓs.-m) not in range $(h.σs)"

    # Learning rates and weights as they were played this round.
    ηs, ws = get_ηw(h, m)
    Δv = newton3(ws, gap, ηs, h.σs)

    # Instantaneous regret: loss of the weighted mixture minus each loss.
    mixloss = dot(ws, ℓs)
    h.Rs .+= mixloss .- ℓs
    h.μs .+= ηs .* h.σs.^2 .* Δv
    h.v = h.v + Δv
end



# @testset / @test come from the Test standard library. Load it here so
# this file does not depend on an included file having loaded it already.
using Test

# Regression test: once an expert's weight has been driven to exactly
# zero, it must still be able to regain weight when the losses reverse.
@testset "survive mode collapse" begin
    K = 2
    T = 10000
    h = Muscada(ones(K), ones(K)/K, H₁) # same scale

    for _ in 1:T
        incur!(h, [1,-1]) # kill expert 1 in weight space
    end

    @test act(h) ≈ [0, 1] # no more weight on expert 1
    @test act(h)[1] == 0  # not just small, really zero.

    for _ in 1:T
        incur!(h, [-1,1]) # equally kill expert 2
    end

    @test act(h) ≈ [1/2, 1/2] # weight came back!
end
