
# Hierarchical logistic-regression model for the `election88` data.
#
# Arguments:
#   n_age, n_age_edu, n_edu, n_region, n_state -- number of levels of each grouping
#       factor, i.e. the lengths of the group-effect vectors b_age, ..., b_region.
#   age, age_edu, edu, region, state -- per-observation level indices; they are used
#       directly as array indices into the corresponding group-effect vectors.
#   black, female -- per-observation covariates; they enter the linear predictor on
#       their own and through the product `female*black`.
#   v_prev -- per-observation covariate with its own coefficient β[4].
#   y -- observed outcomes, each modelled as BernoulliLogit(y_hat[i]).
#
# The random-variable names (σ_*, b_*, β) determine the VarInfo layout, and the
# `validate` closure in `model_with_dataset` reads `b_*` and `β` back by name, so
# renaming them is a breaking change.
Turing.@model function election(n_age, n_age_edu, n_edu, n_region, n_state,
                                age, age_edu, black, edu, female,
                                region, state, v_prev, y)
    # Group-level scales: normal(0, 100) truncated to [0, ∞), i.e. half-normal priors.
    σ_age     ~ Turing.TruncatedNormal(0, 100, 0, Inf)
    σ_edu     ~ Turing.TruncatedNormal(0, 100, 0, Inf)
    σ_age_edu ~ Turing.TruncatedNormal(0, 100, 0, Inf)
    σ_state   ~ Turing.TruncatedNormal(0, 100, 0, Inf)
    σ_region  ~ Turing.TruncatedNormal(0, 100, 0, Inf)

    # Group effects drawn directly as zero-mean normals with covariance σ²I
    # (centered parameterization), one entry per level of each factor.
    b_age     ~ MvNormal(zeros(n_age),     σ_age^2*I)
    b_edu     ~ MvNormal(zeros(n_edu),     σ_edu^2*I)
    b_age_edu ~ MvNormal(zeros(n_age_edu), σ_age_edu^2*I)
    b_state   ~ MvNormal(zeros(n_state),   σ_state^2*I)
    b_region  ~ MvNormal(zeros(n_region),  σ_region^2*I)
    # Fixed-effect coefficients with a broad N(0, 100²) prior. Layout:
    # β[1] intercept, β[2] black, β[3] female, β[4] v_prev, β[5] female×black.
    β         ~ MvNormal(zeros(5), 100^2*I)

    # Linear predictor (log-odds), computed elementwise over observations. Indexing a
    # group-effect vector with an index vector (e.g. `b_age[age]`) gathers the effect of
    # each observation's level.
    y_hat = @. β[1] + β[2]*black + β[3]*female + β[5]*female*black + β[4]*v_prev +
        b_age[age] + b_edu[edu] + b_age_edu[age_edu] + b_state[state] + b_region[region];

    # Broadcast observe statement: y[i] ~ BernoulliLogit(y_hat[i]) for every observation.
    @. y ~ Turing.BernoulliLogit(y_hat)
end

"""
    election_dataset()

Read `datasets/election88.json` (resolved through `datadir`) and convert the
JSON-decoded entries to concrete Julia types.

# Returns
The parsed `Dict`, with these entries converted in place:
- the counts `"N"`, `"n_age"`, `"n_age_edu"`, `"n_edu"`, `"n_region_full"` and
  `"n_state"` converted to `Int`;
- the index/covariate arrays `"age"`, `"age_edu"`, `"black"`, `"edu"`, `"female"`,
  `"region_full"` and `"state"` converted to `Array{Int}`;
- `"v_prev_full"` converted to `Array{Float64}`;
- `"y"` converted to `Array{Bool}`.

Any other keys present in the file are returned unchanged.
"""
function election_dataset()
    json_path = datadir("datasets", "election88.json")
    raw       = JSON.parse(read(json_path, String))

    # Scalar sizes.
    for key in ("N", "n_age", "n_age_edu", "n_edu", "n_region_full", "n_state")
        raw[key] = Int(raw[key])
    end

    # Integer-valued per-observation arrays.
    for key in ("age", "age_edu", "black", "edu", "female", "region_full", "state")
        raw[key] = Array{Int}(raw[key])
    end

    # Remaining per-observation arrays with their own element types.
    raw["v_prev_full"] = Array{Float64}(raw["v_prev_full"])
    raw["y"]           = Array{Bool}(raw["y"])

    return raw
end

"""
    model_with_dataset(::Val{:election}, dataset, n_batch; rng=Random.GLOBAL_RNG)

Set up minibatch inference for the `election` model on the `election88` data.

This method ignores the `dataset` argument: the data always come from
`election_dataset()`.

# Arguments
- `n_batch`: requested minibatch size. It is clamped to the number of observations `N`.
  A non-positive value is rejected.

# Keywords
- `rng`: random number generator. It is used to reshuffle the observations at every epoch
  and is passed to `rand_and_logjac` in `compute_full_elbo!`.

# Returns
A tuple `(prob, b⁻¹, sample_batch!, compute_full_elbo!, validate)`:
- `prob`: a `DynamicPPL.LogDensityFunction` for the model, wrapped in a
  `DynamicPPL.MiniBatchContext`. It initially holds all `N` observations with a
  likelihood scale of 1. It is therefore the full-data log density until
  `sample_batch!` is applied to it.
- `b⁻¹`: `inverse(Bijectors.bijector(model))`.
- `sample_batch!(prob)`: return a copy of `prob` whose observation arguments hold the
  next minibatch and whose likelihood is scaled by `N / length(batch)`.
  - Observations are visited in a fresh random order every epoch, including the first.
  - The last batch of an epoch may be shorter than `n_batch`.
  - The argument is not modified; only the internal epoch state advances. Keep the
    returned value.
- `compute_full_elbo!(q, b⁻¹, M)`: `M`-sample Monte Carlo ELBO estimate for `q`,
  computed against the full-data, unscaled log density.
- `validate(z)`: fraction of the `N` observations whose outcome is predicted correctly
  when the predicted probability is thresholded at 0.5. `z` is a flat parameter vector
  in the layout of `prob.varinfo`.

# Throws
- `ArgumentError` if the clamped batch size is not positive.
"""
function model_with_dataset(::Val{:election}, dataset, n_batch; rng=Random.GLOBAL_RNG)
    data       = election_dataset()
    n_data     = data["N"]
    batch_size = min(n_batch, n_data)
    # `Iterators.partition` used to reject this case. Fail just as loudly here instead
    # of producing empty batches with an Inf/NaN likelihood scale.
    batch_size > 0 ||
        throw(ArgumentError("batch size must be positive, got $batch_size"))

    N             = n_data
    n_age         = data["n_age"]
    n_age_edu     = data["n_age_edu"]
    n_edu         = data["n_edu"]
    n_region_full = data["n_region_full"]
    n_state       = data["n_state"]

    age         = data["age"]
    age_edu     = data["age_edu"]
    black       = data["black"]
    edu         = data["edu"]
    female      = data["female"]
    region_full = data["region_full"]
    state       = data["state"]
    v_prev_full = data["v_prev_full"]
    y           = data["y"]

    model = election(n_age, n_age_edu, n_edu, n_region_full, n_state,
                     age, age_edu, black, edu, female,
                     region_full, state, v_prev_full, y)
    # Use a likelihood scale of 1 (batch = all data) while `prob` still holds the full
    # dataset. `sample_batch!` installs the per-batch scale together with the batch.
    context = DynamicPPL.MiniBatchContext(; batch_size=n_data, npoints=n_data)

    varinfo = DynamicPPL.VarInfo(model)
    b       = Bijectors.bijector(model)
    b⁻¹     = inverse(b)
    prob    = DynamicPPL.LogDensityFunction(model, varinfo, context)

    # The full-data objective reuses the same model with the default context.
    prob_full = DynamicPPL.LogDensityFunction(model)

    # Epoch state for `sample_batch!`:
    # - `perm` is the current visiting order, mutated in place.
    # - `cursor[]` counts how many of its entries have already been served.
    # Starting at `cursor[] == N` makes the first call draw a permutation, so the first
    # epoch is shuffled like every later one. Mutating instead of rebinding captured
    # variables keeps the closure free of a `Core.Box`.
    data_idx = 1:N
    perm     = collect(data_idx)
    cursor   = Ref(N)

    function sample_batch!(_prob)
        if cursor[] >= N
            # Reset and shuffle: yields the same permutation `shuffle(rng, 1:N)` would
            # draw from the current RNG state.
            perm .= data_idx
            Random.shuffle!(rng, perm)
            cursor[] = 0
        end
        first_pos = cursor[] + 1
        last_pos  = min(cursor[] + batch_size, N)
        batch_idx = view(perm, first_pos:last_pos)
        cursor[]  = last_pos

        # Each indexing expression copies the selected observations, so later in-place
        # reshuffles of `perm` do not affect the returned problem.
        _prob = @set _prob.model.args.age     = age[batch_idx]
        _prob = @set _prob.model.args.age_edu = age_edu[batch_idx]
        _prob = @set _prob.model.args.black   = black[batch_idx]
        _prob = @set _prob.model.args.edu     = edu[batch_idx]
        _prob = @set _prob.model.args.female  = female[batch_idx]
        _prob = @set _prob.model.args.region  = region_full[batch_idx]
        _prob = @set _prob.model.args.state   = state[batch_idx]
        _prob = @set _prob.model.args.v_prev  = v_prev_full[batch_idx]
        _prob = @set _prob.model.args.y       = y[batch_idx]
        # Rescale the likelihood so the batch is an unbiased estimate of the full data.
        _prob = @set _prob.context.loglike_scalar = n_data / length(batch_idx)
        return _prob
    end

    function compute_full_elbo!(q, b⁻¹, M)
        # NOTE(review): assumes `rand_and_logjac` returns a matrix with one sample per
        # column (the `view(zs, :, m)` below relies on that) plus the log-Jacobian of
        # `b⁻¹` summed over the M samples. Confirm against its definition.
        zs, ∑logdetjac = rand_and_logjac(rng, q, b⁻¹, M)

        𝔼ℓjoint = mapreduce(+, 1:M) do m
            LogDensityProblems.logdensity(prob_full, view(zs, :, m))
        end / M

        return 𝔼ℓjoint + entropy(q) + ∑logdetjac / M
    end

    function validate(z)
        vi_new    = DynamicPPL.unflatten(prob.varinfo, prob.context, z)
        meta      = vi_new.metadata
        b_age     = meta.b_age.vals
        b_edu     = meta.b_edu.vals
        b_age_edu = meta.b_age_edu.vals
        b_state   = meta.b_state.vals
        b_region  = meta.b_region.vals
        β         = meta.β.vals

        # Same linear predictor as the `election` model, evaluated on all N observations.
        y_hat = @. β[1] + β[2]*black +
            β[3]*female +
            β[5]*female*black +
            β[4]*v_prev_full +
            b_age[age] + b_edu[edu] + b_age_edu[age_edu] + b_state[state] +
            b_region[region_full]

        y_pred = logistic.(y_hat) .> 0.5
        return mean(y .== y_pred)
    end

    return prob, b⁻¹, sample_batch!, compute_full_elbo!, validate
end
