
"""
    ELBOEstimator

Abstract supertype of the ELBO estimators in this file. Each concrete subtype selects its
own `elbo(rng, estimator, ...)` method through multiple dispatch.
"""
abstract type ELBOEstimator end

"""
    ClosedFormEntropy{IsProximal} <: ELBOEstimator

ELBO estimator that evaluates the entropy term analytically as
`d*entropy(φ) + log|det L|` rather than by Monte Carlo.

- `ClosedFormEntropy{false}`: `elbo` returns the Monte Carlo energy plus that entropy.
- `ClosedFormEntropy{true}`: `elbo` returns the Monte Carlo energy only. Presumably the
  entropy is handled elsewhere, e.g. by a proximal step, as the parameter name suggests;
  confirm this with the optimizer code.
"""
struct ClosedFormEntropy{IsProximal} <: ELBOEstimator
end

"""
    elbo(rng, ::ClosedFormEntropy{false}, logdensityprob, λ, q_stop, ψ⁻¹, ϕ, φ, d, M,
         param_type, unflatten)

Estimate the ELBO with a Monte Carlo energy term and a closed-form entropy term.

The computation proceeds as follows:

1. `λ` is unpacked by `unflatten` into a location `m` and scale pieces `s` and `L_low`.
2. `construct_scale(Val(param_type), s, L_low, ϕ)` assembles the scale matrix `L`.
3. `rand(rng, φ, d, M)` draws the base samples `u`. This is presumably a `d×M` matrix;
   that assumes `φ` is univariate, so confirm it.
4. Each column `η = L*u + m` is pushed through `ψ⁻¹`.
5. `logdensity(z) + log|det J_ψ⁻¹(η)|` is averaged over the `M` columns.
6. The entropy `d*entropy(φ) + log|det L|` is added to that average.

`q_stop` is accepted so all estimators share one signature, but it is not used here.
"""
function elbo(rng, ::ClosedFormEntropy{false},
              logdensityprob, λ, q_stop, ψ⁻¹, ϕ, φ, d, M, param_type, unflatten)
    m, s, L_low = unflatten(λ)
    L = construct_scale(Val(param_type), s, L_low, ϕ)

    # Reparameterised draws: column `j` of `ηs` is the j-th sample.
    ηs = L * rand(rng, φ, d, M) .+ m

    # Log-joint of sample `j` in the original space, including the change-of-variables term.
    function energy_at(j)
        zⱼ, ℓjacⱼ = Bijectors.with_logabsdet_jacobian(ψ⁻¹, ηs[:, j])
        return LogDensityProblems.logdensity(logdensityprob, zⱼ) + ℓjacⱼ
    end
    𝔼ℓ = mapreduce(energy_at, +, 1:M) / M

    # Closed-form entropy of η = L*u + m. This assumes the d coordinates of u are iid
    # draws from φ; confirm.
    ℍ = d * entropy(φ) + first(logabsdet(L))
    return 𝔼ℓ + ℍ
end

"""
    elbo(rng, ::ClosedFormEntropy{true}, logdensityprob, λ, q_stop, ψ⁻¹, ϕ, φ, d, M,
         param_type, unflatten)

Return only the Monte Carlo energy term
`mean(logdensity(ψ⁻¹(η)) + log|det J_ψ⁻¹(η)|)` over `M` reparameterised samples
`η = L*u + m`.

No entropy is added to the returned value. Presumably it is handled separately, e.g. by
a proximal step, as the `IsProximal = true` parameter suggests; confirm with the caller.
`q_stop` is unused and only keeps the signature uniform across estimators.
"""
function elbo(rng, ::ClosedFormEntropy{true},
              logdensityprob, λ, q_stop, ψ⁻¹, ϕ, φ, d, M, param_type, unflatten)
    m, s, L_low = unflatten(λ)

    L  = construct_scale(Val(param_type), s, L_low, ϕ)
    u  = rand(rng, φ, d, M)
    ηs = L*u .+ m

    # `i` indexes samples. It is deliberately not named `m`, which holds the location vector.
    𝔼ℓ = mapreduce(+, 1:M) do i
        zᵢ, ℓabsdetψ⁻¹ = Bijectors.with_logabsdet_jacobian(ψ⁻¹, ηs[:, i])
        LogDensityProblems.logdensity(logdensityprob, zᵢ) + ℓabsdetψ⁻¹
    end / M
    return 𝔼ℓ
end

"""
    StickingTheLanding{IsProximal} <: ELBOEstimator

"Sticking the landing" ELBO estimator. Its `elbo` methods average
`log π(z) - log q_stop(η) + log|det J_ψ⁻¹(η)|` over reparameterised samples `η`.
`q_stop` is presumably the variational distribution with gradients stopped; confirm.

- `StickingTheLanding{false}`: returns that average.
- `StickingTheLanding{true}`: returns that average minus the closed-form entropy
  `d*entropy(φ) + log|det L|`.
"""
struct StickingTheLanding{IsProximal} <: ELBOEstimator
end

"""
    elbo(rng, ::StickingTheLanding{false}, logdensityprob, λ, q_stop, ψ⁻¹, ϕ, φ, d, M,
         param_type, unflatten)

Estimate the ELBO as the sample mean of
`logdensity(ψ⁻¹(η)) - logpdf(q_stop, η) + log|det J_ψ⁻¹(η)|` over `M` reparameterised
samples `η = L*u + m`.

The location `m` comes from `unflatten(λ)`. The scale `L` comes from
`construct_scale(Val(param_type), s, L_low, ϕ)`. The base draws `u` come from
`rand(rng, φ, d, M)`. The entropy is estimated by Monte Carlo through `q_stop`.
"""
function elbo(rng, ::StickingTheLanding{false},
              logdensityprob, λ, q_stop, ψ⁻¹, ϕ, φ, d, M, param_type, unflatten)
    m, s, L_low = unflatten(λ)

    L  = construct_scale(Val(param_type), s, L_low, ϕ)
    u  = rand(rng, φ, d, M)
    ηs = L*u .+ m

    # `i` indexes samples. It is deliberately not named `m`, which holds the location vector.
    estimate = mapreduce(+, 1:M) do i
        # Slice once and reuse it for both the bijector and `q_stop`, so each sample makes
        # a single copy.
        ηᵢ = ηs[:, i]
        zᵢ, ℓabsdetψ⁻¹ = Bijectors.with_logabsdet_jacobian(ψ⁻¹, ηᵢ)
        ℓπᵢ = LogDensityProblems.logdensity(logdensityprob, zᵢ)
        ℓqᵢ = logpdf(q_stop, ηᵢ)
        ℓπᵢ - ℓqᵢ + ℓabsdetψ⁻¹
    end / M
    return estimate
end

"""
    elbo(rng, ::StickingTheLanding{true}, logdensityprob, λ, q_stop, ψ⁻¹, ϕ, φ, d, M,
         param_type, unflatten)

Compute the STL sample mean of
`logdensity(ψ⁻¹(η)) - logpdf(q_stop, η) + log|det J_ψ⁻¹(η)|` over `M` samples
`η = L*u + m`, minus the closed-form entropy `d*entropy(φ) + log|det L|`.

Subtracting the exact entropy leaves a value that approximates the energy term alone.
Presumably this keeps the STL entropy-gradient estimate while the exact entropy is
handled by a proximal step; confirm with the optimizer.
"""
function elbo(rng, ::StickingTheLanding{true},
              logdensityprob, λ, q_stop, ψ⁻¹, ϕ, φ, d, M, param_type, unflatten)
    m, s, L_low = unflatten(λ)

    ℍφ = entropy(φ)
    L  = construct_scale(Val(param_type), s, L_low, ϕ)
    u  = rand(rng, φ, d, M)
    ηs = L*u .+ m

    # `i` indexes samples. It is deliberately not named `m`, which holds the location vector.
    # The result is named `stl` rather than `elbo` so it does not shadow this function.
    stl = mapreduce(+, 1:M) do i
        # Slice once and reuse it for both the bijector and `q_stop`.
        ηᵢ = ηs[:, i]
        zᵢ, ℓabsdetψ⁻¹ = Bijectors.with_logabsdet_jacobian(ψ⁻¹, ηᵢ)
        ℓπᵢ = LogDensityProblems.logdensity(logdensityprob, zᵢ)
        ℓqᵢ = logpdf(q_stop, ηᵢ)
        ℓπᵢ - ℓqᵢ + ℓabsdetψ⁻¹
    end / M

    # Closed-form entropy of η = L*u + m, assuming the d coordinates of u are iid draws
    # from φ; confirm.
    ℍ = d*ℍφ + first(logabsdet(L))
    return stl - ℍ
end

