include("mf_setting.jl")
include("../../inference/NUTS/nuts.jl")

##########################
# flow err scaling
##########################
Random.seed!(1)
nsample = 20
n_ref = 650

# Draw every initial state serially, *before* the threaded loop. If the draws happen
# inside `@threads`, the random numbers each sample receives depend on the thread count
# and on task scheduling, so `Random.seed!(1)` alone would not make the run reproducible.
# NOTE(review): assumes `o.q_sampler` draws from the default RNG — confirm.
flow_inits = map(1:nsample) do _
    z0 = o.q_sampler(d) .* D .+ μ
    ρ0 = randn(d)
    u0 = rand()
    (z0, ρ0, u0)
end

"""
    _traj_err(zs, ρs, us, zz, ρρ, uu)

Return a vector with one entry per row: the Euclidean distance between the trajectories
`(zs, ρs, us)` and `(zz, ρρ, uu)` in the joint `(z, ρ, u)` space. Squared differences of
`z` and `ρ` are summed across columns, and the squared difference of `u` is added.
Assumes rows index flow steps, as returned by `MixFlow.flow_trace_*`; confirm.
"""
function _traj_err(zs, ρs, us, zz, ρρ, uu)
    sq = sum(abs2, zs .- zz; dims=2) .+ sum(abs2, ρs .- ρρ; dims=2) .+ (us .- uu) .^ 2
    return sqrt.(vec(sq))
end

# E_fwd[i, k] / E_bwd[i, k]: step-k gap between the same flow run with `ft`-converted
# step size and state vs. the unconverted inputs, for initial state i.
E_fwd = zeros(nsample, n_ref)
E_bwd = zeros(nsample, n_ref)
@threads for i in 1:nsample
    z0, ρ0, u0 = flow_inits[i]

    zs, ρs, us = MixFlow.flow_trace_fwd(
        o, ft.(ϵ), MixFlow.ref_coord, ft.(z0), ft.(ρ0), ft.(u0), n_ref
    )
    zz, ρρ, uu = MixFlow.flow_trace_fwd(o, ϵ, MixFlow.ref_coord, z0, ρ0, u0, n_ref)
    E_fwd[i, :] .= _traj_err(zs, ρs, us, zz, ρρ, uu)

    zbs, ρbs, ubs = MixFlow.flow_trace_bwd(
        o, ft.(ϵ), MixFlow.ref_coord, ft.(z0), ft.(ρ0), ft.(u0), n_ref
    )
    zzb, ρρb, uub = MixFlow.flow_trace_bwd(o, ϵ, MixFlow.ref_coord, z0, ρ0, u0, n_ref)
    E_bwd[i, :] .= _traj_err(zbs, ρbs, ubs, zzb, ρρb, uub)
end
JLD.save("result/flow_err.jld", "fwd_err", E_fwd, "bwd_err", E_bwd, "Ns", [1:n_ref;])

###############
# sampling err scaling
##############
Random.seed!(1)
nsample = 20
n_ref = 650

# Elementwise test functions passed to `MCf` below. f2 is shifted by +1, presumably to
# keep its expectation away from 0 for the relative-error study; confirm.
f1(x) = abs.(x)
f2(x) = sin.(x) .+ 1
f3(x) = 1 ./ (1 .+ exp.(-x))  # logistic sigmoid

# Row i of (F_k, Fs_k) holds the two outputs of `MCf` for test function f_k and initial
# state i (one column per flow step up to n_ref).
F1 = zeros(nsample, n_ref)
F2 = zeros(nsample, n_ref)
F3 = zeros(nsample, n_ref)
Fs1 = zeros(nsample, n_ref)
Fs2 = zeros(nsample, n_ref)
Fs3 = zeros(nsample, n_ref)

# Draw all initial states serially, before the threaded loop. That way the results under
# `Random.seed!(1)` do not depend on the number of threads or on task scheduling.
# NOTE(review): assumes `o.q_sampler` draws from the default RNG — confirm.
sampling_inits = map(1:nsample) do _
    z0 = o.q_sampler(d) .* D .+ μ
    ρ0 = randn(d)
    u0 = rand()
    (z0, ρ0, u0)
end

prog_bar = ProgressMeter.Progress(
    nsample; dt=0.5, barglyphs=ProgressMeter.BarGlyphs("[=> ]"), barlen=50, color=:yellow
)
@threads for i in 1:nsample
    z0, ρ0, u0 = sampling_inits[i]
    # The same forward trajectory, once with `ft`-converted inputs and once unconverted.
    zs, ρs, us = MixFlow.flow_trace_fwd(
        o, ft.(ϵ), MixFlow.ref_coord, ft.(z0), ft.(ρ0), ft.(u0), n_ref
    )
    zz, ρρ, uu = MixFlow.flow_trace_fwd(o, ϵ, MixFlow.ref_coord, z0, ρ0, u0, n_ref)

    for (f, F, Fs) in ((f1, F1, Fs1), (f2, F2, Fs2), (f3, F3, Fs3))
        mf, mfs = MCf(f, zs, ρs, us, zz, ρρ, uu)
        F[i, :] = mf
        Fs[i, :] = mfs
    end

    # update progress bar
    ProgressMeter.next!(prog_bar)
end
JLD.save(
    "result/sampling_err_rel.jld",
    "absx",
    (F1, Fs1),
    "sinx",
    (F2, Fs2),
    "sigmoid",
    (F3, Fs3),
    "Ns",
    [1:n_ref;],
)

################
# lpdf err scaling
###############
Random.seed!(1)
nsample = 100

# Starting states: `nsample` NUTS draws kept after discarding the first `n_warm` rows;
# the momentum and u-coordinates are fresh standard-normal / uniform draws.
# n_warm is also passed as nuts' last argument, presumably the adaptation length
# (confirm against nuts.jl).
# Alternative (unused): T, M, U = MixFlow.Sampler(o, a, MixFlow.ref_coord, 1500; nsample=nsample)
T = let n_warm = 10000
    nuts(μ, 0.7, logp, ∇logp, n_warm + nsample, n_warm)[(n_warm + 1):end, :]
end
M = randn(nsample, d)
U = rand(nsample)

n_ref = 650
# Ds: cumulative log densities with `a` on the raw inputs.
# Dd: the same with `a_big` on `ft`-converted inputs.
Ds = zeros(nsample, n_ref)
Dd = zeros(nsample, n_ref)

prog_bar = ProgressMeter.Progress(
    nsample; dt=0.5, barglyphs=ProgressMeter.BarGlyphs("[=> ]"), barlen=50, color=:yellow
)
@threads for i in 1:nsample
    z0, m0, u0 = T[i, :], M[i, :], U[i]
    Ds[i, :] = MixFlow.log_density_cum(z0, m0, u0, o, a, MixFlow.inv_ref_coord, n_ref)

    zb, mb, ub = ft.(z0), ft.(m0), ft.(u0)
    Dd[i, :] = MixFlow.log_density_cum(zb, mb, ub, o, a_big, MixFlow.inv_ref_coord, n_ref)

    ProgressMeter.next!(prog_bar)
end
JLD.save("result/lpdfs_err.jld", "lpdfs", Ds, "lpdfs_big", Dd)

##########################
# ELBO
##########################
Random.seed!(1)
Ns = [10, 20, 50, 100, 200, 400, 650]
el_size = 200

# Run the identical ELBO sweep for `a` and then `a_big`. `map` over a tuple evaluates
# left to right, so the RNG stream after `Random.seed!(1)` is consumed in the same order.
EL, EL_big = map((a, a_big)) do flow
    MixFlow.elbo_sweep_multiple(
        o, flow, MixFlow.ref_coord, MixFlow.inv_ref_coord, Ns; elbo_size=el_size
    )
end
JLD.save("result/elbo_err.jld", "Ns", Ns, "EL", EL, "EL_big", EL_big)
