
include("turing.jl")

using BenchmarkTools
using DataFrames
using JLD2

"""
    prepare_problem(config)

Build everything the benchmarks need for the problem described by `config`.

Only the fields `problem`, `dataset`, `param_type` and `B` (the batch size) are read
from `config`; any other fields are ignored, so configurations that omit e.g.
`ad_type` can be passed as well.

Returns the tuple
`(logprobdensity, d, λ, ψ⁻¹, ϕ, φ, unflatten, sample_batch!, prepare_full_pass!)`:

- `logprobdensity`: `DynamicPPL.LogDensityFunction` of the model.
- `d`: `LogDensityProblems.dimension(logprobdensity)`.
- `λ`: initial variational parameters, `flatten(m₀, C₀)`.
- `ψ⁻¹`: inverse of `Bijectors.bijector(model)`.
- `ϕ`: `StatsFuns.softplus`.
- `φ`: `Normal()`.
- `unflatten`: second value returned by `get_flatten_utils` for `param_type`.
- `sample_batch!`, `prepare_full_pass!`: callbacks returned by `model_with_dataset`.
"""
function prepare_problem(config)
    # Unpack only what is used, so unrelated fields (`bench`, `M`, `ad_type`)
    # are not required to exist in `config`.
    @unpack problem, dataset, param_type, B = config
    rng = Random.GLOBAL_RNG

    model, context, sample_batch!, prepare_full_pass! = model_with_dataset(
        Val(problem), dataset, B; rng=rng)

    varinfo        = DynamicPPL.VarInfo(model)
    logprobdensity = DynamicPPL.LogDensityFunction(model, varinfo, context)
    ψ⁻¹            = inverse(Bijectors.bijector(model))
    d              = LogDensityProblems.dimension(logprobdensity)

    flatten, unflatten = get_flatten_utils(Val(param_type), logprobdensity)

    # Initial parameters: small random vector and a diagonal matrix with 1e-2 entries.
    m₀ = randn(rng, d)*0.1
    C₀ = Diagonal(fill(1e-2, d))  # dense alternative: Matrix(1e-2*I, d, d)
    λ  = flatten(m₀, C₀)

    ϕ = StatsFuns.softplus
    φ = Normal()
    return logprobdensity, d, λ, ψ⁻¹, ϕ, φ, unflatten, sample_batch!, prepare_full_pass!
end

"""
    run_grad_benchmark(config)

Benchmark a single call to `grad_elbo!` for the problem described by `config`.

Reads `param_type`, `M` and `ad_type` from `config` (the remaining fields are consumed
by `prepare_problem`). `ad_type == :forwarddiff` selects `ForwardDiffAD`; every other
value selects `ZygoteAD`, with a warning if the value is not `:zygote`.

Returns the result of `@benchmark`.
"""
function run_grad_benchmark(config)
    @unpack param_type, M, ad_type = config
    rng = Random.GLOBAL_RNG

    logprobdensity, d, λ, ψ⁻¹, ϕ, φ, unflatten, sample_batch!, _ = prepare_problem(config)

    grad_buf = DiffResults.GradientResult(λ)

    # Keep the symbol and the backend in separate variables instead of rebinding
    # `ad_type` to a value of a different type.
    ad_backend = if ad_type == :forwarddiff
        ForwardDiffAD
    else
        # The fallback is kept for backward compatibility, but is no longer silent:
        # a misspelled symbol would otherwise be benchmarked (and cached) as Zygote.
        if ad_type != :zygote
            @warn "Unrecognized ad_type; falling back to ZygoteAD" ad_type
        end
        ZygoteAD
    end

    # Run once outside the benchmark (presumably to trigger compilation and to
    # surface errors before the timed runs).
    grad_elbo!(rng, logprobdensity, λ, d, M, φ, ψ⁻¹, ϕ,
               unflatten, param_type, ad_backend, sample_batch!, grad_buf)

    return @benchmark grad_elbo!($rng, $logprobdensity, $λ, $d, $M, $φ, $ψ⁻¹, $ϕ,
                                 $unflatten, $param_type, $ad_backend, $sample_batch!,
                                 $grad_buf)
end

"""
    run_elbo_benchmark(config)

Benchmark a single call to `estimate_fulldata_elbo` for the problem described by
`config`.

Reads only `param_type` and `M` from `config` directly; everything else is consumed by
`prepare_problem`. Returns the result of `@benchmark`.
"""
function run_elbo_benchmark(config)
    # Only the fields used below are unpacked; `prepare_problem` reads what it needs
    # from `config` itself.
    @unpack param_type, M = config
    rng = Random.GLOBAL_RNG

    logprobdensity, _, λ, ψ⁻¹, ϕ, _, unflatten, _, prepare_full_pass! = prepare_problem(config)

    q = contruct_q(param_type, λ, ϕ, unflatten)

    # Run once outside the benchmark (presumably to trigger compilation and to
    # surface errors before the timed runs).
    estimate_fulldata_elbo(rng, logprobdensity, q, ψ⁻¹, M, prepare_full_pass!)

    return @benchmark estimate_fulldata_elbo($rng, $logprobdensity, $q, $ψ⁻¹, $M,
                                             $prepare_full_pass!)
end

"""
    format_benchmark(trial)

Summarize a benchmark `trial` as a `NamedTuple`.

For each of the mean, median, minimum and maximum estimates of `trial`, the time is
formatted with `BenchmarkTools.prettytime` and the memory with
`BenchmarkTools.prettymemory`. The fields are, in order: `mean_time`, `median_time`,
`min_time`, `max_time`, `mean_memory`, `median_memory`, `min_memory`, `max_memory`.
"""
function format_benchmark(trial)
    # Compute each estimate once and reuse it for both the time and memory fields.
    avg, med = mean(trial), median(trial)
    lo, hi   = minimum(trial), maximum(trial)

    show_time(estimate)   = BenchmarkTools.prettytime(time(estimate))
    show_memory(estimate) = BenchmarkTools.prettymemory(memory(estimate))

    return (
        mean_time     = show_time(avg),
        median_time   = show_time(med),
        min_time      = show_time(lo),
        max_time      = show_time(hi),
        mean_memory   = show_memory(avg),
        median_memory = show_memory(med),
        min_memory    = show_memory(lo),
        max_memory    = show_memory(hi),
    )
end

"""
    main()

Run the gradient benchmark sweep.

Every configuration is passed to `DrWatson.produce_or_load` with the directory
`datadir("profile")`; the produced data is the formatted benchmark summary merged with
the configuration itself. Three sweeps are run, in this order:

1. Mean-field, batch sizes 10/100/1000 × `:forwarddiff`/`:zygote` × six
   problem/dataset pairs.
2. Mean-field horseshoe on `:prostate` with `:zygote`, batch sizes 1 and 10.
3. Cholesky, batch sizes 10/100/1000 × `:forwarddiff`/`:zygote` × three
   problem/dataset pairs.
"""
function main()
    # Benchmark one configuration through `produce_or_load` and return its result.
    # NOTE(review): the produced Dict has Symbol keys; JLD2's FileIO `save` may
    # require String keys. Confirm with the DrWatson/JLD2 versions in use.
    function profile_gradient(config)
        DrWatson.produce_or_load(datadir("profile"), config) do cfg
            trial   = run_grad_benchmark(cfg)
            summary = Dict(pairs(format_benchmark(trial)))
            merge(summary, Dict(pairs(cfg)))
        end
    end

    batchsizes  = [10, 100, 1000]
    ad_backends = [:forwarddiff, :zygote]

    # Sweep 1: mean-field family.
    meanfield_base = (bench = "grad", param_type = :meanfield, M = 10)
    meanfield_targets = [
        (:autoregress, :eeg),
        (:linearreg,   :gas),
        (:linearreg,   :kin40k),
        (:bnn,         :gas),
        (:bnn,         :kin40k),
        (:nnmf,        :movielens),
    ]
    # Leftmost iterator is the outermost loop: B, then ad_type, then the target.
    for B ∈ batchsizes, ad_type ∈ ad_backends, (problem, dataset) ∈ meanfield_targets
        prob_config = (problem = problem, dataset = dataset, B = B, ad_type = ad_type)
        profile_gradient(merge(meanfield_base, prob_config))
    end

    # Sweep 2: mean-field horseshoe, Zygote only, small batch sizes.
    for B ∈ [1, 10]
        prob_config = (problem = :horseshoe, dataset = :prostate, B = B, ad_type = :zygote)
        profile_gradient(merge(meanfield_base, prob_config))
    end

    # Sweep 3: Cholesky family on a subset of the problems.
    cholesky_base = (bench = "grad", param_type = :cholesky, M = 10)
    cholesky_targets = [
        (:autoregress, :eeg),
        (:linearreg,   :gas),
        (:linearreg,   :kin40k),
    ]
    for B ∈ batchsizes, ad_type ∈ ad_backends, (problem, dataset) ∈ cholesky_targets
        prob_config = (problem = problem, dataset = dataset, B = B, ad_type = ad_type)
        profile_gradient(merge(cholesky_base, prob_config))
    end

    # Disabled full-data ELBO sweep, kept for reference.
    # base_config = (bench      = "elbo",
    #                param_type = :meanfield,
    #                M          = 100)
    # for prob_config ∈ [(problem    = :autoregress,
    #                     dataset    = :eeg),
    #                    #
    #                    (problem    = :linearreg,
    #                     dataset    = NaN),
    #                    #
    #                    (problem    = :bnn,
    #                     dataset    = NaN),
    #                    #
    #                    (problem    = :horseshoe,
    #                     dataset    = NaN),
    #                    ]
    #     config = merge(base_config, prob_config)
    #     DrWatson.produce_or_load(datadir("profile"), config) do config
    #         trial  = run_elbo_benchmark(config)
    #         Dict(pairs(format_benchmark(trial)))
    #     end
    # end
end
