include("mf_setting.jl")
using JLD, LinearAlgebra, StatsBase, StatsPlots
using Plots, ProgressMeter

##########################
# flow error scaling
##########################
res = JLD.load("result/flow_err.jld")

# Forward/backward MixFlow numerical error vs. number of transformations,
# drawn on a log10 y axis. Each error matrix is transposed so that it has one
# row per entry of `Ns`; the median and the percentile ribbon are taken along
# each row. The tiny shift keeps exact zeros finite on the log axis.
let ns = res["Ns"], tiny = 1e-16
    fwd = res["fwd_err"]'
    bwd = res["bwd_err"]'

    plot(
        ns,
        vec(median(fwd; dims=2)) .+ tiny;
        ribbon=get_percentiles(fwd .+ tiny),
        lw=3,
        label="Fwd",
        xlabel="#transformations",
        ylabel="Error",
        title="MixFlow numerical error",
    )
    plot!(
        ns,
        vec(median(bwd; dims=2)) .+ tiny;
        ribbon=get_percentiles(bwd .+ tiny),
        lw=3,
        label="Bwd",
    )
end

# Axis scale, x range and presentation settings for the figure.
plot!(;
    yaxis=:log10,
    xlim=(0, 60),
    size=(800, 500),
    xtickfontsize=30,
    ytickfontsize=30,
    xrotation=0,
    margin=10Plots.mm,
    guidefontsize=30,
    titlefontsize=30,
    legendfontsize=18,
    legend=:bottomright,
)
savefig("figure/flow_err_log.png")

res = JLD.load("result/flow_err.jld")

# Same forward/backward error curves as the log-scale figure, but on a linear
# y axis. The first series opens a fresh plot; the second is overlaid on it.
let ns = res["Ns"], shift = 1e-16
    for (k, (key, lbl)) in enumerate((("fwd_err", "Fwd"), ("bwd_err", "Bwd")))
        errs = res[key]'                      # one row per entry of `ns`
        draw! = k == 1 ? plot : plot!
        draw!(
            ns,
            vec(median(errs; dims=2)) .+ shift;
            ribbon=get_percentiles(errs .+ shift),
            lw=3,
            label=lbl,
        )
    end
end

# Labels and presentation settings, applied once to the finished figure.
plot!(;
    xlabel="#transformations",
    ylabel="Error",
    title="MixFlow numerical error",
    size=(800, 500),
    xtickfontsize=30,
    ytickfontsize=30,
    xrotation=15,
    margin=10Plots.mm,
    guidefontsize=30,
    titlefontsize=30,
    legendfontsize=30,
    legend=:bottomright,
)
savefig("figure/flow_err.png")

##########################
# sampling error scaling
##########################

res = JLD.load("result/sampling_err_rel.jld")
Ns = res["Ns"]
F1, Fs1 = res["absx"]
F2, Fs2 = res["sinx"]
F3, Fs3 = res["sigmoid"]

# Relative error between the running averages of `F` and of the reference `Fs`
# for each test function (|x|, sin x + 1, sigmoid). The reference average is
# computed once and reused for both the numerator and the denominator;
# `average_run2(F)` is still evaluated before `average_run2(Fs)`.
# The 1e-16 shift keeps exact zeros finite on the log10 axis.
# NOTE(review): a zero reference average would give Inf/NaN — confirm it cannot occur.
r1, r2, r3 = map(((F1, Fs1), (F2, Fs2), (F3, Fs3))) do (F, Fs)
    est = average_run2(F)
    ref = average_run2(Fs)
    abs.(est .- ref) ./ abs.(ref) .+ 1e-16
end

# One median curve with percentile ribbon per test function; after transposing,
# each row of r_i' corresponds to one entry of `Ns`.
p1 = plot(Ns, vec(median(r1'; dims=2)); ribbon=get_percentiles(r1'), lw=3, label="|x|")
plot!(Ns, vec(median(r2'; dims=2)); ribbon=get_percentiles(r2'), lw=3, label="sinx+1")
plot!(Ns, vec(median(r3'; dims=2)); ribbon=get_percentiles(r3'), lw=3, label="sigmoid")
plot!(;
    xlabel="#transformations",
    ylabel="Rel. err.",
    title="MixFlow sampling error",
    xrotation=20,
    legend=:bottomright,
    yaxis=:log10,
    yticks=[1e-14, 1e-11, 1e-8, 1e-5, 1e-2],
    size=(800, 500),
    xtickfontsize=30,
    ytickfontsize=30,
    margin=10Plots.mm,
    guidefontsize=30,
    legendfontsize=20,
    titlefontsize=30,
)

savefig("figure/sampling_err_log_rel.png")

##########################
# lpdf error scaling
##########################
res = JLD.load("result/lpdfs_err.jld")
# `lpdfs` are the log-densities from the ordinary computation and `lpdfs_big`
# the reference values (presumably higher precision, e.g. BigFloat — TODO confirm).
# No reshaping is done: both are used as 2-D arrays whose columns index the
# number of transformations. After transposing, each row of `Es` holds the
# absolute errors for one #transformations.
Ds = res["lpdfs"]
Dd = res["lpdfs_big"]
Es = abs.(Ds .- Dd)'
Ns = collect(1:size(Ds, 2))

# Absolute log-density error on a linear y axis (median with percentile ribbon).
p = plot(Ns, vec(median(Es; dims=2)); ribbon=get_percentiles(Es), lw=3, label="")
plot!(;
    xlabel="#transformations",
    ylabel="Error",
    title="MixFlow log-density error",
    xrotation=20,
    legend=:topright,
    size=(800, 500),
    xtickfontsize=30,
    ytickfontsize=30,
    margin=10Plots.mm,
    guidefontsize=30,
    legendfontsize=20,
    titlefontsize=30,
)

savefig("figure/lpdfs_err.png")

# Same curve on a log10 y axis; the 1e-20 shift keeps exact zeros finite.
p = plot(Ns, vec(median(Es; dims=2)) .+ 1e-20; ribbon=get_percentiles(Es), lw=3, label="")
plot!(;
    xlabel="#transformations",
    ylabel="Error",
    title="MixFlow log-density error",
    xrotation=20,
    legend=:topright,
    yaxis=:log10,
    size=(800, 500),
    xtickfontsize=30,
    ytickfontsize=30,
    margin=10Plots.mm,
    guidefontsize=30,
    legendfontsize=20,
    titlefontsize=30,
)

savefig("figure/lpdfs_err_log.png")

res = JLD.load("result/lpdfs_err.jld")
# `lpdfs` are the computed log-densities and `lpdfs_big` the reference values
# (presumably higher precision — TODO confirm). No reshaping is done: columns
# index the number of transformations, so after transposing, each row of `Es`
# holds the relative errors for one #transformations.
# NOTE(review): an exact zero in `Dd` would make the relative error Inf/NaN.
Ds = res["lpdfs"]
Dd = res["lpdfs_big"]
Es = abs.(Ds .- Dd)' ./ abs.(Dd)'
Ns = collect(1:size(Ds, 2))

# Relative log-density error on a linear y axis (median with percentile ribbon).
p1 = plot(Ns, vec(median(Es; dims=2)); ribbon=get_percentiles(Es), lw=3, label="")
plot!(;
    xlabel="#transformations",
    ylabel="Rel. err.",
    title="MixFlow log-density error",
    xrotation=20,
    legend=:topright,
    size=(800, 500),
    xtickfontsize=30,
    ytickfontsize=30,
    margin=10Plots.mm,
    guidefontsize=30,
    legendfontsize=20,
    titlefontsize=30,
)

savefig("figure/lpdfs_err_rel.png")

# Same curve on a log10 y axis; the 1e-20 shift keeps exact zeros finite.
p1 = plot(Ns, vec(median(Es; dims=2)) .+ 1e-20; ribbon=get_percentiles(Es), lw=3, label="")
# Explicit tick positions so the log axis is labelled across its whole range
# (1e-18 up to 1e-1) rather than leaving the automatic, sparser ticks.
plot!(;
    xlabel="#transformations",
    ylabel="Rel. err.",
    title="MixFlow log-density error",
    xrotation=20,
    legend=:topright,
    yaxis=:log10,
    yticks=[1e-18, 1e-15, 1e-12, 1e-9, 1e-6, 1e-3, 1e-1],
    size=(800, 500),
    xtickfontsize=30,
    ytickfontsize=25,
    margin=10Plots.mm,
    guidefontsize=30,
    legendfontsize=20,
    titlefontsize=30,
)

savefig("figure/lpdfs_err_log_rel.png")

##########################
# ELBO
##########################
res = JLD.load("result/elbo_err.jld")
Ns, EL, EL_big = res["Ns"], res["EL"], res["EL_big"]
# Absolute ELBO error (currently unused; kept for the disabled error plots).
# el_err = abs.(EL_big .- EL)'

# Mean ELBO estimate over the rows of each matrix (one value per column, one
# column per entry of Ns), for the numerical run and for the `EL_big` reference.
p1 = plot(Ns, vec(mean(EL; dims=1)); lw=3, label="numerical")
plot!(Ns, vec(mean(EL_big; dims=1)); lw=3, label="exact")

# Labels and presentation settings in a single call.
plot!(;
    xlabel="#transformations",
    ylabel="ELBO",
    title="MixFlow ELBO est.",
    xrotation=20,
    legend=:bottomright,
    size=(800, 500),
    xtickfontsize=25,
    ytickfontsize=30,
    margin=10Plots.mm,
    guidefontsize=30,
    legendfontsize=20,
    titlefontsize=30,
)

savefig("figure/elbos.png")

# p = plot(Ns, vec(median(el_err; dims=2)); ribbon=get_percentiles(el_err), lw=3, label="")
# plot!(;
#     xlabel="#transformations",
#     ylabel="Error",
#     title="MixFlow ELBO est. err.",
#     xrotation=20,
#     legend=:topright,
# )
# plot!(;
#     size=(800, 500),
#     xtickfontsize=30,
#     ytickfontsize=30,
#     margin=10Plots.mm,
#     guidefontsize=30,
#     legendfontsize=20,
#     titlefontsize=30,
# )

# savefig("figure/elbo_err.png")

# p = plot(Ns, vec(median(el_err; dims=2)); ribbon=get_percentiles(el_err), lw=3, label="")
# plot!(;
#     xlabel="#transformations",
#     ylabel="Error",
#     title="MixFlow ELBO est. err.",
#     xrotation=20,
#     legend=:topright,
# )
# plot!(;
#     yaxis=:log10,
#     size=(800, 500),
#     xtickfontsize=30,
#     ytickfontsize=30,
#     margin=10Plots.mm,
#     guidefontsize=30,
#     legendfontsize=20,
#     titlefontsize=30,
# )

# savefig("figure/elbo_err_log.png")

# res = JLD.load("result/elbo_err.jld")
# Ns = res["Ns"]
# EL = res["EL"]
# EL_big = res["EL_big"]
# el_err = abs.(EL_big .- EL)' ./ abs.(EL_big)'

# p1 = plot(Ns, vec(median(el_err; dims=2)); ribbon=get_percentiles(el_err), lw=3, label="")
# plot!(;
#     xlabel="#transformations",
#     ylabel="Error",
#     title="MixFlow ELBO est. err.",
#     xrotation=20,
#     legend=:topright,
# )
# plot!(;
#     size=(800, 500),
#     xtickfontsize=30,
#     ytickfontsize=30,
#     margin=10Plots.mm,
#     guidefontsize=30,
#     legendfontsize=20,
#     titlefontsize=30,
# )

# savefig("figure/elbo_err_rel.png")

###################
# one step error
######################
# Box plots of the single-map forward and backward errors (E_fwd, E_bwd).
# The result file is read once; both arrays are flattened to vectors.
E_fwd, E_bwd = let delta = JLD.load("result/delta.jld")
    vec(delta["E_fwd"]), vec(delta["E_bwd"])
end
p1 = boxplot(
    ["Fwd err." "Bwd err."], [E_fwd E_bwd]; legend=false, title="MixFlow single map err."
)
plot!(;
    ylabel="Error",
    yaxis=:log10,
    size=(800, 500),
    # Optional fixed y range / ticks, currently disabled:
    # ylims=(1e-16, 1e-11),
    # yticks=[1e-15, 1e-14, 1e-12, 1e-10],
    xtickfontsize=30,
    ytickfontsize=30,
    margin=10Plots.mm,
    guidefontsize=30,
    legendfontsize=20,
    titlefontsize=30,
)
savefig(p1, "figure/delta.png")

##########################
# shadowing window
###########################
res = JLD.load("result/windows.jld")
Ns = res["Ns"]
W_fwd, W_bwd = res["W_fwd"], res["W_bwd"]
T_fwd, T_bwd = res["T_fwd"], res["T_bwd"]

# Two figures sharing the same font/margin settings. Each uses a median curve
# with a percentile ribbon; the matrices are transposed so that each row
# corresponds to one entry of `Ns`.
let fonts = (
        xtickfontsize=30,
        ytickfontsize=30,
        margin=10Plots.mm,
        guidefontsize=30,
        legendfontsize=20,
        titlefontsize=30,
    )
    # Shadowing window (ϵ) size for the forward and backward directions.
    global p = plot(
        Ns, vec(median(W_fwd'; dims=2)); ribbon=get_percentiles(W_fwd'), lw=3, label="Fwd"
    )
    plot!(Ns, vec(median(W_bwd'; dims=2)); ribbon=get_percentiles(W_bwd'), lw=3, label="Bwd")
    plot!(;
        xlabel="#transformations",
        ylabel="",
        title="MixFlow ϵ size",
        xrotation=20,
        legend=:topleft,
        size=(900, 500),
        fonts...,
    )
    savefig("figure/shadowing.png")

    # Wall-clock time spent computing ϵ.
    global p = plot(
        Ns, vec(median(T_fwd'; dims=2)); ribbon=get_percentiles(T_fwd'), lw=3, label="Fwd"
    )
    plot!(Ns, vec(median(T_bwd'; dims=2)); ribbon=get_percentiles(T_bwd'), lw=3, label="Bwd")
    plot!(;
        xlabel="#transformations",
        ylabel="Wall time in sec",
        title="ϵ computation time",
        xrotation=20,
        legend=:topleft,
        size=(800, 500),
        fonts...,
    )
    savefig("figure/shadowing_time.png")
end
