
"""
    unflatten_and_stack(zs, prob, filter_keys = collect(keys(prob.varinfo.metadata)))

Unflatten every column of `zs` into the variables of `prob.varinfo` and stack the
per-sample values of each variable along a new trailing dimension.

Each column `z` is passed to `DynamicPPL.unflatten(prob.varinfo, prob.context, z)`.
The `vals` of every entry in the resulting `metadata` are collected. Only keys that
also appear in `filter_keys` are kept in the result.

# Arguments
- `zs`: flattened samples, one sample per column (`size(zs, 2)` samples).
- `prob`: object exposing `varinfo` and `context` fields used by `DynamicPPL.unflatten`.
- `filter_keys`: keys to keep; defaults to every key of `prob.varinfo.metadata`.

# Returns
A `Dict` mapping each kept key to an array with element type `eltype(zs)` and size
`(size(vals)..., size(zs, 2))`. Slice `n` along the last dimension holds sample `n`.
Sizes are taken from the first sample.

# Throws
- `ArgumentError` if `zs` has no columns.
- `DimensionMismatch` if a later sample's value does not match the first sample's shape.
"""
function unflatten_and_stack(zs::Array, prob, filter_keys = collect(keys(prob.varinfo.metadata)))
    n_samples = size(zs, 2)
    n_samples > 0 ||
        throw(ArgumentError("`zs` must contain at least one sample (column)"))

    samples = map(eachcol(zs)) do z
        vi′ = DynamicPPL.unflatten(prob.varinfo, prob.context, z)
        # Use a name distinct from any variable of the enclosing function. Reusing an
        # outer local's name here would make the closure capture and reassign it,
        # which forces a `Core.Box` and makes the whole function type-unstable.
        meta = vi′.metadata
        Dict(k => meta[k].vals for k ∈ keys(meta))
    end

    z₀ = first(samples)
    kept_keys = filter(k -> k ∈ filter_keys, collect(keys(z₀)))

    stacked = map(kept_keys) do k
        v₀  = z₀[k]
        buf = Array{eltype(zs)}(undef, size(v₀)..., n_samples)
        # One `:` per value dimension, so `buf[colons..., n]` addresses sample `n` for
        # values of any rank. Every slice of `buf` is written, so no `undef` entries
        # leak into the result.
        colons = ntuple(_ -> Colon(), ndims(v₀))
        for n in 1:n_samples
            buf[colons..., n] = samples[n][k]
        end
        k => buf
    end
    return Dict(stacked)
end

"""
    rand_and_logjac(rng, q, b⁻¹, n_samples)

Draw `ηs = rand(rng, q, n_samples)`, push every column of `ηs` through `b⁻¹`, and
accumulate the log-absolute-determinant of the Jacobian of each transformation.

Assumes `rand(rng, q, n_samples)` returns one sample per column (e.g. a multivariate
`q` producing a `d × n_samples` matrix) — TODO confirm against callers.

# Returns
A tuple `(zs, ∑logdetjac)`:
- `zs`: the transformed columns, concatenated horizontally with `hcat`.
- `∑logdetjac`: the sum of the per-column log-abs-det-Jacobians returned by
  `Bijectors.with_logabsdet_jacobian`, starting from `zero(eltype(ηs))`.
"""
function rand_and_logjac(rng::Random.AbstractRNG, q, b⁻¹, n_samples::Int)
    ηs = rand(rng, q, n_samples)

    # Collect a (z, logdetjac) pair per column rather than `+=`-ing into a variable
    # captured by a closure. Reassigning a captured variable forces a `Core.Box`, which
    # makes the accumulator and everything downstream type-unstable. Each column is
    # passed as a copied `Vector` (`ηs[:, m]`), not a view.
    results = map(axes(ηs, 2)) do m
        Bijectors.with_logabsdet_jacobian(b⁻¹, ηs[:, m])
    end

    zs = reduce(hcat, map(first, results))
    # `init` seeds the sum with zero(eltype(ηs)); with `init` the reduction is a
    # sequential left fold over the per-column log-Jacobians.
    ∑logdetjac = sum(last, results; init = zero(eltype(ηs)))
    return zs, ∑logdetjac
end

