
using HDF5

# Sinusoid-basis regression model with Student-t distributed standardized residuals.
#
# Arguments
# - `y`:   observations; one entry per row of `txω`.
# - `txω`: matrix whose second dimension indexes the `N` basis functions
#          (`N = size(txω, 2)`). The caller in this file passes `t[batch] .* ω'`.
# - `z`:   placeholder that is overwritten with the standardized residuals before
#          the final `~` statement (the caller in this file passes `NaN`).
#          Presumably it is a model argument so that Turing treats `z` as observed
#          rather than as a latent variable — TODO confirm.
Turing.@model function spectral(y, txω, z)
    N   = size(txω, 2)

    # Hyperparameters of the priors on `d` and `λ⁻¹`.
    α_d = 1e-2
    β_d = 1e-2
    α_λ = 1e-2
    β_λ = 1e-2

    # d: degrees of freedom of the Student-t residual distribution.
    d   ~ Gamma(α_d, 1/β_d)
    # λ⁻¹: scalar the residuals are divided by below.
    # NOTE(review): despite the "inverse" in the name, the residual is *divided* by
    # this value, so it acts as a scale — confirm this is intended.
    λ⁻¹ ~ InverseGamma(α_λ, 1/β_λ)
    # τ (scalar) and λ (one per basis function): half-Cauchy variables whose product
    # is passed as the second argument of the prior on `a`.
    τ   ~ Turing.Truncated(Cauchy(0, 1), 0, Inf)
    λ   ~ Turing.filldist(Turing.Truncated(Cauchy(0, 1), 0, Inf), N)
    # ϕ: one phase per basis function, uniform on (-π, π).
    ϕ   ~ Turing.filldist(Uniform(-π, π), N)
    # a: one amplitude per basis function.
    # NOTE(review): how the vector `τ*λ` is interpreted (std. dev. vs. variance)
    # depends on the Distributions version in use — confirm.
    a   ~ MvNormal(zeros(N) , τ*λ)

    # txω ∈ ℝ^{T × N} (N = size(txω, 2)); ϕ' is 1 × N and a ∈ ℝ^{N}, so μ ∈ ℝ^{T}.
    μ = cos.(ϕ' .- txω)*a
    # Standardized residuals, given an element-wise Student-t distribution.
    # NOTE(review): no log-scale (Jacobian) term for the division by `λ⁻¹` is added
    # to the log density here — confirm whether that is intended.
    z = (μ - y) / λ⁻¹
    @. z ~ TDist(d)
end

# Load the strain time series from the GWOSC HDF5 file and rescale it by its
# standard deviation.
#
# Returns the data stored under "strain/Strain" divided by `std` of that data.
# The file is opened read-only inside a `do` block so the handle is closed before
# returning, even if reading throws.
function time_series_dataset(::Val{:ligo})
    # Alternative recording used previously: "H-H1_LOSC_4_V2-1126257414-4096.hdf5"
    path = datadir("datasets", "H-H1_GWOSC_4KHZ_R1-1126257415-4096.hdf5")
    y = HDF5.h5open(path, "r") do io
        read(io, "strain/Strain")
    end
    return y / std(y)
end

# Build the minibatched log-density problem for the `spectral` model on the first
# 2000 samples of the LIGO strain series, together with helper closures.
#
# Arguments
# - `::Val{:spectral}`: dispatch tag selecting this model.
# - `dataset`: currently unused; the LIGO series is always loaded.
# - `batchsize`: requested minibatch size; clamped to the number of data points.
# - `rng` (keyword): random number generator used for shuffling the data indices
#   and for drawing from `q` in `validate`.
#
# Returns the tuple `(prob, b⁻¹, sample_train_batch!, compute_full_elbo!, validate)`:
# - `prob`: `DynamicPPL.LogDensityFunction` of the model on the first minibatch,
#   evaluated under a `MiniBatchContext`.
# - `b⁻¹`: inverse of the model's bijector.
# - `sample_train_batch!(prob)`: returns a copy of `prob` pointing at the next minibatch.
# - `compute_full_elbo!(q, b⁻¹, M)`: placeholder, always returns 0.
# - `validate(q, b⁻¹)`: RMSE between the data and the prediction averaged over
#   draws from `q`.
function model_with_dataset(::Val{:spectral}, dataset, batchsize; rng=Random.GLOBAL_RNG)
    x      = time_series_dataset(Val(:ligo))[1:2000]
    n_data = length(x)
    fs     = 4e+3
    # Sample times: sample index divided by the sampling rate `fs`
    # (the original multiplied by `fs`, which does not yield time stamps).
    t      = collect(0:n_data-1) / fs

    N     = 256
    f_min = 10
    f_max = 2e+3
    # NOTE(review): ω enters the model as cos(ϕ - t*ω), i.e. as an angular frequency.
    # If f_min/f_max are meant in Hz, a factor of 2π is missing — confirm.
    ω     = range(f_min, f_max; length=N)

    batchsize = min(batchsize, n_data)

    data_idx  = 1:n_data
    data_itr  = Iterators.partition(data_idx, batchsize)
    batch_idx = first(data_itr)

    context = DynamicPPL.MiniBatchContext(; batch_size=batchsize, npoints=n_data)
    # The observations live in `x`; the original referenced an undefined `y` here.
    model   = spectral(x[batch_idx], t[batch_idx].*ω', NaN)

    varinfo = DynamicPPL.VarInfo(model)
    b       = Bijectors.bijector(model)
    b⁻¹     = inverse(b)
    prob    = DynamicPPL.LogDensityFunction(model, varinfo, context)

    # Return a copy of `prob` whose data arguments and likelihood scaling refer to
    # the next minibatch. Advances the shared `data_itr` state of this closure.
    function sample_train_batch!(prob)
        # Start a new, shuffled pass once the current one is exhausted.
        if isempty(data_itr)
            data_itr = Iterators.partition(shuffle(rng, data_idx), batchsize)
        end
        batch, data_itr = Iterators.peel(data_itr)

        prob = @set prob.model.args.y   = x[batch]
        prob = @set prob.model.args.txω = t[batch].*ω'
        # Rescale the minibatch likelihood to the size of the full dataset.
        prob = @set prob.context.loglike_scalar = n_data/length(batch)
        return prob
    end

    # Placeholder for a full-data ELBO estimate; always returns 0.
    # TODO: the earlier sketch (prior term + Student-t likelihood over all data +
    # entropy of `q` + log-det-Jacobian) was written for a model with parameters
    # θ, σ⁻¹ and d and has not been ported to the spectral model.
    function compute_full_elbo!(q, b⁻¹, M)
        return 0
    end

    # RMSE between `x` and the model mean `cos.(ϕ' .- t.*ω') * a`, averaged over
    # `n_samples` draws from `q` mapped through `b⁻¹`.
    # NOTE(review): there is no held-out split in this function, so the error is
    # computed on the same 2000 samples used for training — confirm this is the
    # intended validation metric. The original body referenced `valid_idx`, `p`
    # and a variable `θ`, none of which exist for this model.
    function validate(q, b⁻¹)
        n_samples = 100
        txω_all   = t .* ω'
        y_sum     = mapreduce(+, 1:n_samples) do _
            z        = b⁻¹(rand(rng, q))
            vi_new   = DynamicPPL.unflatten(prob.varinfo, prob.context, z)
            metadata = vi_new.metadata
            ϕ        = metadata[Symbol("ϕ")].vals
            a        = metadata[Symbol("a")].vals
            cos.(ϕ' .- txω_all) * a
        end
        y_pred = y_sum / n_samples
        return sqrt(mean(abs2, x - y_pred))
    end

    return prob, b⁻¹, sample_train_batch!, compute_full_elbo!, validate
end
