
import JSON

# Non-centred hierarchical linear regression for the radon data.
#
# Arguments:
# - `J`:             number of counties; length of the county intercept vector `α`.
# - `log_radon`:     observed responses, modelled as `MvNormal(μ, σ²I)`.
# - `county_idx`:    per-observation index into `α` selecting that observation's
#                    intercept; entries are assumed to lie in `1:J`.
# - `log_uppm`:      per-observation covariate with coefficient `β₁`.
# - `floor_measure`: per-observation covariate with coefficient `β₂`.
#
# Intercepts use the non-centred form `α = μ_α .+ α_raw * σ_α` with
# `α_raw ~ N(0, I)`; `σ` and `σ_α` are standard normals truncated to `[0, ∞)`.
# NOTE(review): the flattened parameter layout presumably follows the order of the
# `~` statements (DynamicPPL VarInfo), so avoid reordering them.
Turing.@model function radon(J, log_radon, county_idx, log_uppm, floor_measure)
    σ     ~ Turing.TruncatedNormal(0, 1, 0, Inf)
    σ_α   ~ Turing.TruncatedNormal(0, 1, 0, Inf)
    μ_α   ~ Normal(0, 10)
    β₁    ~ Normal(0, 10)
    β₂    ~ Normal(0, 10)
    α_raw ~ MvNormal(zeros(J), I)

    α = μ_α .+ α_raw * σ_α

    # `county_idx` comes from the data file, so bounds checking stays enabled: an
    # out-of-range county must raise a BoundsError rather than read arbitrary memory.
    μ = α[county_idx] + log_uppm * β₁ + floor_measure * β₂

    log_radon ~ MvNormal(μ, σ * σ * I)
end

"""
    radon_dataset()

Read `datasets/radon_all.json` (path resolved with `datadir`) and return the tuple
`(J, log_radon, county_idx, log_uppm, floor_measure)`.

`log_radon` and `log_uppm` are converted to `Float64` arrays, `county_idx` and
`floor_measure` to `Int` arrays; `J` is returned exactly as parsed from the `"J"` entry.
"""
function radon_dataset()
    raw = JSON.parse(read(datadir("datasets", "radon_all.json"), String))

    # Convert a parsed JSON array (element type `Any`) into a concretely typed array.
    as_float(key) = Array{Float64}(raw[key])
    as_int(key)   = Array{Int}(raw[key])

    return (
        raw["J"],
        as_float("log_radon"),
        as_int("county_idx"),
        as_float("log_uppm"),
        as_int("floor_measure"),
    )
end

"""
    model_with_dataset(::Val{:radon}, dataset, batchsize; rng=Random.GLOBAL_RNG)

Set up the `radon` hierarchical regression for minibatch variational inference.

The data is always loaded via `radon_dataset()`; `dataset` is accepted for interface
compatibility but is not used by this method.

# Arguments
- `batchsize`: minibatch size, clamped to the number of observations.

# Keywords
- `rng`: RNG used to reshuffle minibatches between epochs and passed to
  `rand_and_logjac` in `compute_full_elbo!`.

# Returns
A tuple `(prob, b⁻¹, sample_train_batch!, compute_full_elbo!, validate)`:
- `prob`: `DynamicPPL.LogDensityFunction` of the model on the first minibatch, with
  the batch log-likelihood scaled by `n_data / batchsize` via `MiniBatchContext`.
- `b⁻¹`: inverse of the model's bijector.
- `sample_train_batch!(prob)`: `prob` rebuilt on the next minibatch, with the
  likelihood scale matched to that batch's length.
- `compute_full_elbo!(q, b⁻¹, M)`: `M`-sample Monte Carlo ELBO on the full dataset.
- `validate(z)`: RMSE between the model's predicted mean at `z` and all of `log_radon`.
"""
function model_with_dataset(::Val{:radon}, dataset, batchsize; rng=Random.GLOBAL_RNG)
    J, log_radon, county_idx, log_uppm, floor_measure = radon_dataset()
    n_data    = length(log_radon)
    batchsize = min(batchsize, n_data)

    data_idx = 1:n_data
    # The first epoch visits the data in order; `sample_train_batch!` reshuffles once
    # this iterator is exhausted. That closure reassigns `data_itr`, so it is boxed.
    data_itr = Iterators.partition(data_idx, batchsize)

    # Instantiate the minibatch model on the first batch (peeked, not consumed) so its
    # data agrees with MiniBatchContext's `n_data / batchsize` likelihood scale; using
    # all `n_data` points here would overweight the likelihood by that factor.
    init_idx = first(data_itr)
    model    = radon(
        J,
        log_radon[init_idx],
        county_idx[init_idx],
        log_uppm[init_idx],
        floor_measure[init_idx],
    )
    context  = DynamicPPL.MiniBatchContext(; batch_size=batchsize, npoints=n_data)

    varinfo = DynamicPPL.VarInfo(model)
    b       = Bijectors.bijector(model)
    b⁻¹     = inverse(b)
    prob    = DynamicPPL.LogDensityFunction(model, varinfo, context)

    model_full = radon(J, log_radon, county_idx, log_uppm, floor_measure)
    prob_full  = DynamicPPL.LogDensityFunction(model_full)

    # Return `_prob` rebuilt on the next minibatch, starting a freshly shuffled epoch
    # when the current one is exhausted. The last batch of an epoch may be shorter
    # than `batchsize`, so the likelihood scale is recomputed from its actual length.
    function sample_train_batch!(_prob)
        if isempty(data_itr)
            data_itr = Iterators.partition(shuffle(rng, data_idx), batchsize)
        end
        batch_idx, data_itr = Iterators.peel(data_itr)

        _prob = @set _prob.model.args.log_radon     = log_radon[batch_idx]
        _prob = @set _prob.model.args.county_idx    = county_idx[batch_idx]
        _prob = @set _prob.model.args.log_uppm      = log_uppm[batch_idx]
        _prob = @set _prob.model.args.floor_measure = floor_measure[batch_idx]
        _prob = @set _prob.context.loglike_scalar   = n_data / length(batch_idx)
        return _prob
    end

    # Monte Carlo ELBO on the full dataset: the full-data log joint averaged over the
    # `M` sample columns of `zs`, plus `entropy(q)` and the averaged log-Jacobian term
    # returned by `rand_and_logjac`.
    function compute_full_elbo!(q, b⁻¹, M)
        zs, ∑logdetjac = rand_and_logjac(rng, q, b⁻¹, M)

        𝔼ℓjoint = mapreduce(+, 1:M) do m
            LogDensityProblems.logdensity(prob_full, view(zs, :, m))
        end / M

        return 𝔼ℓjoint + entropy(q) + ∑logdetjac / M
    end

    # Root-mean-square error of the predicted mean at parameter vector `z` against
    # every observed `log_radon`. NOTE(review): `z` is unflattened with `prob.varinfo`,
    # so it presumably lives in the same space as that varinfo — confirm with callers.
    function validate(z)
        vi_new = DynamicPPL.unflatten(prob.varinfo, prob.context, z)
        meta   = vi_new.metadata
        α_raw  = meta.α_raw.vals
        μ_α    = meta.μ_α.vals[1]
        σ_α    = meta.σ_α.vals[1]
        β₁     = meta.β₁.vals[1]
        β₂     = meta.β₂.vals[1]
        α      = μ_α .+ α_raw * σ_α

        μ = α[county_idx] + log_uppm * β₁ + floor_measure * β₂
        return sqrt(mean(abs2, μ - log_radon))
    end

    return prob, b⁻¹, sample_train_batch!, compute_full_elbo!, validate
end
