
using DrWatson
@quickactivate "BBVIConvergence"

using DataFrames
using DataFramesMeta
using JLD2
using Plots, StatsPlots
using Statistics
using HDF5
using ProgressMeter
using UnPack

"""
    load_data(path)

Load the `"data"` entry of every `.jld2` file directly inside `path` and vertically
concatenate them into a single table.

Only files whose name ends in `.jld2` are considered, so backups such as `x.jld2.bak`
are ignored.

Throws an `ArgumentError` if `path` contains no `.jld2` files (previously this silently
returned an empty `Any[]` instead of a table).
"""
function load_data(path)
    fnames = filter(fname -> endswith(fname, ".jld2"), readdir(path; join=true))
    @info("found $(length(fnames)) JLD2 files")
    if isempty(fnames)
        throw(ArgumentError("no .jld2 files found in $(path)"))
    end
    # NOTE(review): each "data" entry is presumably a DataFrame — confirm against the writer.
    dfs = map(fname -> JLD2.load(fname, "data"), fnames)
    @info("loaded dataframes")
    # `reduce(vcat, ...)` avoids splatting a possibly long vector into varargs; DataFrames
    # specializes it for vectors of data frames.
    return reduce(vcat, dfs)
end

#=
    Columns of the experiment DataFrame (as produced by `load_data`), recorded from the REPL:

    julia> names(df)
    9-element Vector{String}:
     "t"
     "elbo_minibatch"
     "elbo"
     "logstepsize"
     "param_type"
     "optimizer"
     "problem"
     "dataset"
     "covariance_type"
=#

"""
    iteration_slice(df, iteration; optimizer, dataset, problem, param_type, covariance_type)

Select the rows of `df` that match the given configuration, were recorded at
`t == iteration`, and have a non-missing, finite ELBO.

Returns only the `:elbo` and `:logstepsize` columns of those rows.
"""
function iteration_slice(df, iteration;
                         optimizer,
                         dataset,
                         problem,
                         param_type,
                         covariance_type)
    matching = @subset df begin
        :dataset         .== dataset
        :problem         .== problem
        :optimizer       .== optimizer
        :param_type      .== param_type
        :covariance_type .== covariance_type
        :t               .== iteration
        # Drop rows whose ELBO is missing or non-finite (Inf/NaN).
        .!ismissing.(:elbo)
        isfinite.(:elbo)
    end
    return @select(matching, :elbo, :logstepsize)
end

"""
    stepsize_slice(df, logstepsize; optimizer, dataset, problem, param_type, covariance_type)

Select the rows of `df` that match the given configuration, used the log step size
`logstepsize`, and have a non-missing, finite ELBO.

Returns only the `:elbo` and `:t` (iteration) columns of those rows.
"""
function stepsize_slice(df, logstepsize;
                        optimizer,
                        dataset,
                        problem,
                        param_type,
                        covariance_type)
    matching = @subset df begin
        :dataset         .== dataset
        :problem         .== problem
        :optimizer       .== optimizer
        :param_type      .== param_type
        :covariance_type .== covariance_type
        :logstepsize     .== logstepsize
        # Drop rows whose ELBO is missing or non-finite (Inf/NaN).
        .!ismissing.(:elbo)
        isfinite.(:elbo)
    end
    return @select(matching, :elbo, :t)
end

"""
    statistics(df, group_key)

Summarize the `:elbo` column of `df` within each group defined by `group_key`.

Returns one row per group, containing the grouping column(s) and the columns
`:elbo_mean`, `:elbo_median`, `:elbo_min`, `:elbo_max`, `:elbo_90` (90th percentile)
and `:elbo_10` (10th percentile).
"""
function statistics(df, group_key)
    grouped = groupby(df, group_key)
    return @combine grouped begin
        :elbo_mean   = mean(:elbo)
        :elbo_median = median(:elbo)
        :elbo_min    = minimum(:elbo)
        :elbo_max    = maximum(:elbo)
        :elbo_90     = quantile(:elbo, 0.9)
        :elbo_10     = quantile(:elbo, 0.1)
    end
end

"""
    plot_envelope(df, iteration; optimizer, dataset, problem, param_type, covariance_type)

Plot the ELBO envelope across step sizes at a fixed `iteration` and return the plotted series.

Selects the matching runs via `iteration_slice` and groups them by `:logstepsize`. It then
overlays on the current plot the median ELBO against the step size `10^logstepsize`
(log-scaled x axis), with a ribbon spanning the 10th–90th percentiles. The lower y-limit
is set to the median of the per-step-size medians. This clips the lower half of the
curve, presumably to zoom in on the better-performing step sizes.

Returns `(x, y, y_p, y_m)`:
- `x`: step sizes (`10^logstepsize`);
- `y`: median ELBOs (`Vector{Float64}`);
- `y_p`, `y_m`: absolute distances from the median to the 90th and 10th percentiles.

Throws an `ArgumentError` if no finite ELBO values match the configuration.
"""
function plot_envelope(df, iteration;
                       optimizer,
                       dataset,
                       problem,
                       param_type,
                       covariance_type)
    slice = iteration_slice(df, iteration;
                            optimizer       = optimizer,
                            dataset         = dataset,
                            problem         = problem,
                            param_type      = param_type,
                            covariance_type = covariance_type)
    # Without this check an empty selection fails inside `quantile` with an opaque message.
    if nrow(slice) == 0
        throw(ArgumentError("no finite ELBO values at iteration $(iteration) for " *
                            "optimizer=$(optimizer), dataset=$(dataset), " *
                            "problem=$(problem), param_type=$(param_type), " *
                            "covariance_type=$(covariance_type)"))
    end

    stats = statistics(slice, :logstepsize)
    x   = 10.0 .^ stats[:, :logstepsize]
    y   = Vector{Float64}(stats[:, :elbo_median])
    # Ribbon offsets are non-negative distances below (10th pct) and above (90th pct).
    y_p = abs.(stats[:, :elbo_90] .- y)
    y_m = abs.(stats[:, :elbo_10] .- y)
    display(Plots.plot!(x, y; xscale=:log10, ylims=(quantile(y, 0.5), Inf),
                        ribbon=(y_m, y_p)))
    return x, y, y_p, y_m
end

"""
    plot_losscurve(df, logstepsize; optimizer, dataset, problem, param_type, covariance_type)

Plot the ELBO learning curve for a fixed `logstepsize` and return the plotted series.

Selects the matching runs via `stepsize_slice` and groups them by iteration `:t`. It then
overlays on the current plot the median ELBO against `t` (log-scaled x axis), with a
ribbon spanning the 10th–90th percentiles.

Returns `(x, y, y_p, y_m)`:
- `x`: iterations (`Vector{Int}`);
- `y`: median ELBOs (`Vector{Float64}`);
- `y_p`, `y_m`: absolute distances from the median to the 90th and 10th percentiles.
"""
function plot_losscurve(df, logstepsize;
                        optimizer,
                        dataset,
                        problem,
                        param_type,
                        covariance_type)
    slice = stepsize_slice(df, logstepsize;
                           optimizer       = optimizer,
                           dataset         = dataset,
                           problem         = problem,
                           param_type      = param_type,
                           covariance_type = covariance_type)

    stats = statistics(slice, :t)
    x   = Vector{Int}(stats[:, :t])
    y   = Vector{Float64}(stats[:, :elbo_median])
    # Ribbon offsets are non-negative distances below (10th pct) and above (90th pct).
    y_p = abs.(stats[:, :elbo_90] .- y)
    y_m = abs.(stats[:, :elbo_10] .- y)
    display(Plots.plot!(x, y; xscale=:log10, ribbon=(y_m, y_p)))
    return x, y, y_p, y_m
end

"""
    export_envelope(df, io=nothing; iteration_range, dataset, problem, covariance_type)

For each iteration in `iteration_range`, start a fresh plot and draw the step-size envelope
of every (optimizer, parameterization) variant via `plot_envelope`.

When `io` is given (e.g. an open HDF5 file), each variant's series is written as two
datasets:
- `<opt>_<param>_<iteration>_x`: the step sizes;
- `<opt>_<param>_<iteration>_y`: a 3×n matrix whose rows are the median, the distance to
  the 90th percentile, and the distance to the 10th percentile.
"""
function export_envelope(df, io=nothing;
                         iteration_range,
                         dataset,
                         problem,
                         covariance_type)
    variants = ((opt=:adam,          param=:linear),
                (opt=:proximal_adam, param=:linear),
                (opt=:adam,          param=:nonlinear))

    for iteration in iteration_range
        display(Plots.plot())  # fresh figure for this iteration
        for variant in variants
            x, y, y_p_abs, y_m_abs = plot_envelope(df, iteration;
                                                   optimizer  = variant.opt,
                                                   dataset,
                                                   problem,
                                                   param_type = variant.param,
                                                   covariance_type)
            isnothing(io) && continue
            prefix = string(variant.opt, "_", variant.param, "_", iteration)
            write(io, prefix * "_x", x)
            write(io, prefix * "_y", Array(permutedims(hcat(y, y_p_abs, y_m_abs))))
        end
    end
end

"""
    export_envelopes(df = load_data(datadir("experiment")))

For each (problem, dataset, covariance_type) configuration listed below, draw the
step-size envelopes with `export_envelope`. The results are written to
`datadir("processed", "envelope_" * savename(config) * ".h5")`.

The `datadir("processed")` directory is created if it does not already exist.
"""
function export_envelopes(df = load_data(datadir("experiment")))
    full  = [1000, 5_000, 10_000, 50_000, 100_000]
    # NOTE(review): tennis omits 100_000 — presumably those runs are shorter; confirm.
    short = [1000, 5_000, 10_000, 50_000]

    configs = [
        (problem=:linearreg,    dataset=:keggundirected, covariance_type=:meanfield, iteration_range=full),
        (problem=:linearreg,    dataset=:song,           covariance_type=:meanfield, iteration_range=full),
        (problem=:linearreg,    dataset=:buzz,           covariance_type=:meanfield, iteration_range=full),
        (problem=:linearreg,    dataset=:houseelectric,  covariance_type=:meanfield, iteration_range=full),
        (problem=:radon,        dataset=:radon,          covariance_type=:meanfield, iteration_range=full),
        (problem=:election,     dataset=:election,       covariance_type=:meanfield, iteration_range=full),
        (problem=:bradleyterry, dataset=:tennis,         covariance_type=:meanfield, iteration_range=short),
        (problem=:autoregress,  dataset=:eeg,            covariance_type=:meanfield, iteration_range=full),

        (problem=:linearreg,    dataset=:keggundirected, covariance_type=:cholesky,  iteration_range=full),
        (problem=:linearreg,    dataset=:song,           covariance_type=:cholesky,  iteration_range=full),
        (problem=:linearreg,    dataset=:buzz,           covariance_type=:cholesky,  iteration_range=full),
        # disabled: (problem=:linearreg, dataset=:houseelectric, covariance_type=:cholesky, iteration_range=full),
        (problem=:election,     dataset=:election,       covariance_type=:cholesky,  iteration_range=full),
        (problem=:autoregress,  dataset=:eeg,            covariance_type=:cholesky,  iteration_range=full),
    ]

    # h5open(path, "w") cannot create a file inside a missing directory, so create it first.
    mkpath(datadir("processed"))

    @showprogress for config in configs
        @info("", config...)
        @unpack problem, dataset, covariance_type, iteration_range = config
        h5open(datadir("processed", "envelope_" * savename(config) * ".h5"), "w") do io
            export_envelope(df, io;
                            iteration_range,
                            dataset,
                            problem,
                            covariance_type)
        end
    end
end

"""
    export_losscurve(df, io=nothing; stepsize_range, dataset, problem, covariance_type)

For each log10 step size in `stepsize_range`, start a fresh plot and draw the loss curve of
every (optimizer, parameterization) variant via `plot_losscurve`.

When `io` is given (e.g. an open HDF5 file), each variant's series is written as two
datasets:
- `<opt>_<param>_<stepsize>_x`: the iterations;
- `<opt>_<param>_<stepsize>_y`: a 3×n matrix whose rows are the median, the distance to
  the 90th percentile, and the distance to the 10th percentile.
"""
function export_losscurve(df, io=nothing;
                          stepsize_range,
                          dataset,
                          problem,
                          covariance_type)
    variants = ((opt=:adam,          param=:linear),
                (opt=:proximal_adam, param=:linear),
                (opt=:adam,          param=:nonlinear))

    for stepsize in stepsize_range
        display(Plots.plot())  # fresh figure for this step size
        for variant in variants
            x, y, y_p_abs, y_m_abs = plot_losscurve(df, stepsize;
                                                    optimizer  = variant.opt,
                                                    dataset,
                                                    problem,
                                                    param_type = variant.param,
                                                    covariance_type)
            isnothing(io) && continue
            prefix = string(variant.opt, "_", variant.param, "_", stepsize)
            write(io, prefix * "_x", x)
            write(io, prefix * "_y", Array(permutedims(hcat(y, y_p_abs, y_m_abs))))
        end
    end
end

"""
    export_losscurves(df = load_data(datadir("experiment")))

For each (problem, dataset, covariance_type) configuration listed below, draw the loss
curves for a fixed set of log10 step sizes with `export_losscurve`. The results are
written to `datadir("processed", "losscurve_" * savename(config) * ".h5")`.

The `datadir("processed")` directory is created if it does not already exist.
"""
function export_losscurves(df = load_data(datadir("experiment")))
    steps = [-4.0, -3.5, -3.0, -2.5, -2.0]  # log10 step sizes

    configs = [
        (problem=:linearreg,    dataset=:keggundirected, covariance_type=:meanfield, logstepsize_range=steps),
        (problem=:linearreg,    dataset=:song,           covariance_type=:meanfield, logstepsize_range=steps),
        (problem=:linearreg,    dataset=:buzz,           covariance_type=:meanfield, logstepsize_range=steps),
        (problem=:linearreg,    dataset=:houseelectric,  covariance_type=:meanfield, logstepsize_range=steps),
        (problem=:radon,        dataset=:radon,          covariance_type=:meanfield, logstepsize_range=steps),
        (problem=:election,     dataset=:election,       covariance_type=:meanfield, logstepsize_range=steps),
        (problem=:bradleyterry, dataset=:tennis,         covariance_type=:meanfield, logstepsize_range=steps),
        (problem=:autoregress,  dataset=:eeg,            covariance_type=:meanfield, logstepsize_range=steps),

        (problem=:linearreg,    dataset=:keggundirected, covariance_type=:cholesky,  logstepsize_range=steps),
        # disabled: (problem=:linearreg, dataset=:song, covariance_type=:cholesky, logstepsize_range=steps),
        (problem=:linearreg,    dataset=:buzz,           covariance_type=:cholesky,  logstepsize_range=steps),
        (problem=:election,     dataset=:election,       covariance_type=:cholesky,  logstepsize_range=steps),
        (problem=:autoregress,  dataset=:eeg,            covariance_type=:cholesky,  logstepsize_range=steps),
    ]

    # h5open(path, "w") cannot create a file inside a missing directory, so create it first.
    mkpath(datadir("processed"))

    @showprogress for config in configs
        @info("", config...)
        @unpack problem, dataset, covariance_type, logstepsize_range = config
        h5open(datadir("processed", "losscurve_" * savename(config) * ".h5"), "w") do io
            export_losscurve(df, io;
                             stepsize_range = logstepsize_range,
                             dataset,
                             problem,
                             covariance_type)
        end
    end
end
