
using DrWatson
@quickactivate "BBVIConvergence"

using DelimitedFiles
using Plots, StatsPlots
using Random123

include(srcdir("BBVIConvergence.jl"))
include("utils.jl")

# Draw one sample from `q`, map it through `b⁻¹`, and average the log-density of
# that point over one full pass of batches. Returns `(average, prob)`, where `prob`
# is whatever the last `sample_batch!` call returned, so the caller can thread it
# into the next pass.
function _full_pass_logdensity(rng, prob, q, b⁻¹, sample_batch!, prepare_full_pass!)
    z = b⁻¹(rand(rng, q))
    n_batch = prepare_full_pass!()
    n_batch ≥ 1 || throw(ArgumentError(
        "prepare_full_pass! must report at least one batch, got $n_batch"))

    # Peel the first batch so the accumulator takes the log-density's own type.
    prob = sample_batch!(prob)
    acc = LogDensityProblems.logdensity(prob, z)
    for _ in 2:n_batch
        prob = sample_batch!(prob)
        acc += LogDensityProblems.logdensity(prob, z)
    end
    return acc / n_batch, prob
end

"""
    estimate_valid_elbo(rng, logdensityprob, q, b⁻¹, M,
                        sample_batch!, prepare_full_pass!)

Return a Monte Carlo average of batch-averaged log-densities.

For each of `M` iterations this draws `ζ = rand(rng, q)`, maps it to `z = b⁻¹(ζ)`,
calls `prepare_full_pass!()` to obtain the number of batches, and then, for every
batch, replaces the problem with `sample_batch!(problem)` and evaluates
`LogDensityProblems.logdensity(problem, z)`. The per-batch values are averaged over
the batches, and those averages are averaged over the `M` draws.

The problem object returned by `sample_batch!` is carried over from one batch to
the next and from one draw to the next, starting from `logdensityprob`.

# Arguments
- `rng`: random number generator passed to `rand`.
- `logdensityprob`: initial problem object handed to the first `sample_batch!` call.
- `q`: object sampled with `rand(rng, q)`.
- `b⁻¹`: callable applied to each draw before the log-density is evaluated.
- `M`: number of draws; must be at least 1.
- `sample_batch!`: callable taking the current problem and returning the problem
  to evaluate for the next batch.
- `prepare_full_pass!`: zero-argument callable returning the number of batches.

# Throws
- `ArgumentError` if `M < 1` or if `prepare_full_pass!()` returns less than 1.
"""
function estimate_valid_elbo(rng, logdensityprob, q, b⁻¹, M,
                             sample_batch!, prepare_full_pass!)
    M ≥ 1 || throw(ArgumentError("M must be at least 1, got $M"))

    # The problem is threaded through explicitly instead of being reassigned
    # inside a closure, which would force Julia to box the captured variable.
    acc, prob = _full_pass_logdensity(rng, logdensityprob, q, b⁻¹,
                                      sample_batch!, prepare_full_pass!)
    for _ in 2:M
        val, prob = _full_pass_logdensity(rng, prob, q, b⁻¹,
                                          sample_batch!, prepare_full_pass!)
        acc += val
    end
    return acc / M
end

"""
    main()

Run one BBVI experiment on the `:radon` model and plot the recorded `:full_elbo`
statistic against iteration.

The function seeds a `Random123.Philox4x` generator and the global RNG, builds the
problem with `model_with_dataset(Val(:radon), :radon, batchsize; rng)`, and runs
`bbvi` for `T = 3000` iterations with `M = 10`, `Optimisers.Adam(1e-2)`, a
mean-field parameterisation and `ReverseDiffAD`. Every 100 iterations the callback
records `compute_full_elbo!(q, b⁻¹, 300)` under the key `full_elbo`.

# Returns
The tuple `(q, b⁻¹, elbo_full)`: the first value returned by `bbvi`, the callable
returned by `model_with_dataset`, and the values extracted by
`filter_stats(:full_elbo, stats)`.
"""
function main()
    key  = 1
    seed = (0x97dcb950eaebcfba, 0x741d36b68bef6415)
    rng  = Random123.Philox4x(UInt64, seed, 8)
    Random123.set_counter!(rng, key)
    Random.seed!(key)

    batchsize = 100

    prob, b⁻¹, sample_train_batch!, compute_full_elbo!, _ = model_with_dataset(
        Val(:radon), :radon, batchsize; rng=rng)

    # Alternative toy problem (kept for debugging):
    # Turing.@model function toy_model()
    #     y ~ LogNormal(1, 1)
    #     z ~ MvNormal([1.0, 2.0], [0.1 0.001; 0.001 0.1])
    # end
    # model   = toy_model()
    # b       = Bijectors.bijector(model)
    # b⁻¹     = inverse(b)
    # prob    = DynamicPPL.LogDensityFunction(model)
    # sample_train_batch! = nothing
    # compute_full_elbo!  = nothing

    ϕ = identity
    #ϕ = StatsFuns.softplus

    d  = LogDensityProblems.dimension(prob)
    M  = 10
    T  = 3000

    # Callback handed to `bbvi`; returns a NamedTuple of extra statistics.
    # Every 100 iterations it adds `full_elbo`, otherwise it returns an empty tuple.
    function callback!(t, stats, λ, q, sub_elbo, g)
        #display(Plots.plot(-hist[1:t], yscale=:log10))

        if !isnothing(compute_full_elbo!) && mod(t, 100) == 0
            # NOTE(review): the meaning of the literal 300 is defined by
            # `compute_full_elbo!` (presumably a sample count) — confirm.
            full_elbo = compute_full_elbo!(q, b⁻¹, 300)
            (full_elbo=full_elbo,)
        else
            NamedTuple()
        end
    end
    #callback! = nothing

    m₀ = zeros(d)
    # Diagonal entries are inverse(ϕ)(1.0), i.e. the value that ϕ maps to 1.0.
    C₀ = Diagonal(fill(inverse(ϕ)(1.0), d))  #Matrix(1e-2*I, d, d)

    q, λ, stats = bbvi(prob, M, T, m₀, C₀;
                       rng           = rng,
                       ψ⁻¹           = b⁻¹,
                       ϕ             = ϕ,
                       #optimizer     = Optimisers.Descent(1e-4),
                       #optimizer     = Optimisers.Momentum(1e-3),
                       #optimizer     = Optimisers.AMSGrad(1e-2),
                       optimizer     = Optimisers.Adam(1e-2),
                       #optimizer     = Optimisers.AdaGrad(1e-1),
                       #optimizer     = ProxGenAdam(1e-2),
                       show_progress = true,
                       callback!     = callback!,
                       #param_type    = :cholesky,
                       param_type    = :meanfield,
                       #ad_type       = ForwardDiffAD,
                       #ad_type       = ZygoteAD,
                       ad_type       = ReverseDiffAD,
                       sample_batch  = sample_train_batch!)


    #flatten, unflatten = get_flatten_utils(Val(:meanfield), prob)
    #linear_response(rng, prob, λ, b⁻¹, ϕ, Normal(), 10, unflatten)

    # `t_sub`/`elbo_sub` are only used by the commented-out plotting code below.
    t_sub,  elbo_sub  = filter_stats(:elbo,      stats)
    t_full, elbo_full = filter_stats(:full_elbo, stats)

    # display(Plots.plot(      t_full,      elbo_full, label="Adam (linear param.)"))

    # for (name, opt, est) ∈ [(    "Proximal", ProxGenAdam(1e-2),   ClosedFormEntropy{true}()),
    #                         (         "STL",        Adam(1e-2), StickingTheLanding{false}()),
    #                         ("Proximal STL", ProxGenAdam(1e-2),  StickingTheLanding{true}()),
    #                          ]
    #     q, stats = bbvi(prob, M, T, m₀, C₀;
    #                     rng           = rng,
    #                     ψ⁻¹           = b⁻¹,
    #                     ϕ             = ϕ,
    #                     optimizer      = opt,
    #                     show_progress  = true,
    #                     callback!      = callback!,
    #                     param_type     = :meanfield,
    #                     estimator_type = est,
    #                     #ad_type        = ForwardDiffAD,
    #                     #ad_type        = ZygoteAD,
    #                     ad_type        = ReverseDiffAD,
    #                     sample_batch   = sample_train_batch!)

    #     t, elbo = filter_stats(:full_elbo, stats)
    #     display(Plots.plot!(t, elbo, ylims=[quantile(elbo_full, 0.2), Inf], label=name, xlabel="Iteration", ylabel="ELBO"))
    # end

    #l  = @layout [a; b]
    #p1 = Plots.plot(t_full,              val_hist, ylims=[-Inf, quantile(val_hist, 0.9)])
    #p1 = Plots.plot(1:length(sub_hist),  sub_hist,  ylims=[quantile(sub_hist, 0.1), Inf])
    #p2 = Plots.plot(t_full,             full_hist,  ylims=[quantile(full_hist, 0.1), Inf])
    #display(Plots.plot(p1, p2, p3, layout = l))
    #display(Plots.plot(p1, p2, layout = l))


    #display(Plots.plot!( t_sub,  elbo_sub))
    # NOTE(review): `plot!` draws onto the current Plots figure when one exists, so
    # repeated calls in one session overlay their curves — confirm this is intended.
    display(Plots.plot!(t_full, elbo_full))#, ylims=[quantile(elbo_sub, 0.05), Inf]))

    #sub_hist, full_hist
    return q, b⁻¹, elbo_full
end
