
Turing.@model function autoregress(x, X_auto, p, ::Any)
    #=
    Cuevas, Alejandro, et al.
    "Bayesian autoregressive spectral estimation."
    IEEE Latin American Conference on Computational Intelligence (LA-CCI). 2021.

    Gaussian AR(p) regression:
        σ² ~ N⁺(0, σₑ²)          half-normal prior on the noise variance
        a  ~ N(0, σ² I_p)
        x  ~ N(X_auto * a, σ² I)

    `X_auto` holds one row of lagged regressors per entry of `x`, and `p` is the
    AR order (the length of `a`). The fourth argument is unused; presumably it
    only keeps the signature aligned with `autoregress_robust`/`autoregress_sparse`.
    =#
    σₑ = 0.5
    σ² ~ Turing.TruncatedNormal(0, σₑ, 0, Inf)
    a  ~ MvNormal(zeros(p), σ²*I)

    # `MvNormal(μ, Σ)` takes a covariance. The noise variance is σ², matching the
    # prior on `a`; passing `sqrt(σ²)*I` would make the noise variance σ instead.
    x ~ MvNormal(X_auto*a, σ²*I)
end

Turing.@model function autoregress_robust(x, X_auto, p, z)
    #=
    Christmas, Jacqueline, and Richard Everson.
    "Robust autoregression: Student-t innovations using variational Bayes."
    IEEE Transactions on Signal Processing 59.1 (2010): 48-57.

    AR(p) model with Student-t innovations:
        d     ~ Gamma(α_d, 1/β_d)            degrees of freedom of the innovations
        λ⁻²   ~ InverseGamma(α_λ, 1/β_λ)     squared innovation scale
        δ⁻²ᵢ  ~ InverseGamma(α_δ, 1/β_δ)     prior variance of θᵢ, i = 1..p
        θ     ~ MvNormal(0, sqrt.(δ⁻²))      vector argument (std devs, see note)
        (X_auto*θ - x) / sqrt(λ⁻²)  ~ iid TDist(d)

    `X_auto` holds one row of lagged regressors per entry of `x`; `p` is the AR
    order. `z` is a placeholder argument: it is overwritten below with the
    standardized residuals. Because it is a (non-`missing`) model argument,
    DynamicPPL treats the `~` on it as an observation rather than a latent draw.

    NOTE(review): the order of the `~` statements presumably fixes the layout of
    the flattened parameter vector; do not reorder without checking callers.
    NOTE(review): only the standardized residual is scored. The density of `x`
    itself would also need the scale Jacobian `-length(x) * log(sqrt(λ⁻²))`,
    which is not added here. Confirm whether this is intentional.
    NOTE(review): `d` has no lower safeguard (unlike `d_safe` in
    `autoregress_sparse`), so `TDist(d)` may receive values extremely close to 0.
    NOTE(review): `MvNormal(μ, v::AbstractVector)` is the deprecated
    Distributions.jl form that reads `v` as standard deviations. Confirm the
    installed version still accepts it.
    =#

    # Gamma / InverseGamma are shape–scale parameterized, hence the `1/β` scales.
    α_δ = 1e-2
    β_δ = 1e-2
    α_d = 1e-2
    β_d = 1e-2
    α_λ = 1e-2
    β_λ = 1e-2

    d   ~ Gamma(α_d, 1/β_d)
    λ⁻² ~ InverseGamma(α_λ, 1/β_λ)
    δ⁻² ~ Turing.filldist(InverseGamma(α_δ, 1/β_δ), p)
    θ   ~ MvNormal(zeros(p), sqrt.(δ⁻²))

    # Rebind the placeholder `z` to the standardized residuals and observe them.
    z = (X_auto*θ - x) / sqrt(λ⁻²)
    @. z ~ TDist(d)
end

Turing.@model function autoregress_sparse(x, X_auto, p, z)
    #=
    Christmas, Jacqueline, and Richard Everson.
    "Robust autoregression: Student-t innovations using variational Bayes."
    IEEE Transactions on Signal Processing 59.1 (2010).

    Carvalho, Carlos M., Nicholas G. Polson, and James G. Scott.
    "Handling sparsity via the horseshoe."
    Artificial Intelligence and Statistics. PMLR, 2009.

    Carvalho, Carlos M., Nicholas G. Polson, and James G. Scott.
    "The horseshoe estimator for sparse signals."
    Biometrika 97.2 (2010): 465-480.

    Robust AR(p) model with Student-t innovations and a horseshoe prior on the
    AR coefficients:
        d    ~ Gamma(α_d, 1/β_d)           degrees of freedom of the innovations
        σ⁻¹  ~ InverseGamma(α_σ, 1/β_σ)    factor multiplying the residuals
        τ    ~ Cauchy⁺(0, 1)               global shrinkage (half-Cauchy)
        λᵢ   ~ Cauchy⁺(0, 1)               local shrinkage, i = 1..p
        θ    ~ MvNormal(0, τ*λ)            vector argument (std devs, see note)
        (X_auto*θ - x) * σ⁻¹ ~ iid TDist(d + eps)

    `X_auto` holds one row of lagged regressors per entry of `x`; `p` is the AR
    order. `z` is a placeholder argument (`model_with_dataset` passes `NaN`). It
    is overwritten below with the scaled residuals. Because it is a non-`missing`
    model argument, DynamicPPL treats the `~` on it as an observation.

    NOTE(review): the `~` order (d, σ⁻¹, τ, λ, θ) presumably fixes the layout of
    the flattened parameter vector that `model_with_dataset` unflattens by name
    (`[:θ, :σ⁻¹, :d]`); do not reorder or rename without checking that code.
    NOTE(review): only the scaled residual is scored. The density of `x` itself
    would also need the Jacobian `+length(x) * log(σ⁻¹)`, which is not added here.
    Confirm whether this is intentional.
    NOTE(review): `MvNormal(μ, v::AbstractVector)` is the deprecated
    Distributions.jl form that reads `v` as standard deviations, i.e. θᵢ has
    standard deviation τλᵢ as in the horseshoe. Confirm version support.
    =#

    # Gamma / InverseGamma are shape–scale parameterized, hence the `1/β` scales.
    α_d = 1e-2
    β_d = 1e-2
    α_σ = 1e-2
    β_σ = 1e-2

    d   ~ Gamma(α_d, 1/β_d)
    σ⁻¹ ~ InverseGamma(α_σ, 1/β_σ)

    # Half-Cauchy global (τ) and local (λ) shrinkage scales.
    τ ~ Turing.Truncated(Cauchy(0, 1), 0, Inf)
    λ ~ Turing.filldist(Turing.Truncated(Cauchy(0, 1), 0, Inf), p)
    θ ~ MvNormal(zeros(p), τ*λ)
    # Nudge the degrees of freedom away from 0 so `TDist` stays well defined.
    d_safe = d + eps(eltype(X_auto))

    # Rebind the placeholder `z` to the scaled residuals and observe them.
    z = (X_auto*θ - x)*σ⁻¹
    @. z ~ TDist(d_safe)
end

"""
    read_int24(io, N) -> Vector{Int32}

Read `N` consecutive 24-bit signed integers from `io`. Each integer is stored in
big-endian byte order (most significant byte first). Return them sign-extended
to `Int32`.

Throws `EOFError` if `io` holds fewer than `3N` bytes.
"""
function read_int24(io, N)
    # `read!` into a pre-sized buffer throws `EOFError` on a short read. Plain
    # `read(io, n)` would instead silently return fewer bytes, and the
    # `@inbounds` loop below would then read past the end of the buffer.
    buf = read!(io, Vector{UInt8}(undef, 3 * N))
    x   = Vector{Int32}(undef, N)
    @inbounds for i in eachindex(x)
        j = 3 * (i - 1)
        # Place the three bytes in the top 24 bits of a UInt32, reinterpret as
        # Int32, then arithmetic-shift right by 8 to sign-extend the value.
        word = (UInt32(buf[j + 1]) << 24) |
               (UInt32(buf[j + 2]) << 16) |
               (UInt32(buf[j + 3]) << 8)
        x[i] = reinterpret(Int32, word) >> 8
    end
    return x
end

function time_series_dataset(::Val{:ecg})
    #=
    Goldberger, A., et al. (2000).
    "PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research
    resource for complex physiologic signals."
    Circulation. 101 (23), pp. e215–e220.

    Jager, F., et al., (2003)
    "Long-term ST database: a reference for the development and evaluation
    of automated ischaemia detectors and for the study of the dynamics of
    myocardial ischaemia."
    Medical & Biological Engineering & Computing.

    https://physionet.org/content/ltstdb/1.0.0/

    The individual recordings of the Long-Term ST Database are between
    21 and 24 hours in duration, and contain two or three ECG signals.
    Each ECG signal has been digitized at 250 samples per second with
    12-bit resolution over a range of ±10 millivolts. Each record includes
    a set of meticulously verified ST episode and signal quality annotations,
    together with additional beat-by-beat QRS annotations and ST level measurements.

    Returns a `Vector{Float64}` of length N = 20642000.
    =#

    # Record s20274.dat (other candidate records: s20311.dat, s20641.dat).
    # N = 20642000 samples per channel, i.e. ≈22.9 h at 250 Hz, consistent with
    # the 21–24 h record length quoted above. The file is decoded as 2N 24-bit
    # samples (presumably two channels), and only the first N values are kept.
    #
    # NOTE(review): WFDB signal files usually interleave channels sample by
    # sample. If that holds here, `raw[1:N]` mixes both channels rather than
    # selecting one. Also confirm that the record header declares a 24-bit
    # storage format, as `read_int24` assumes.
    N    = 20642000
    path = datadir("datasets", "timeseries", "s20274.dat")
    # The do-block form guarantees the file handle is closed, even if decoding throws.
    raw  = open(path, "r") do io
        read_int24(io, 2*N)
    end

    # Map the signed 24-bit integers to [-1, 1), then apply the gain (presumably
    # to recover the ±10 mV range described above).
    gain = 10
    arr  = raw[1:N] / 2^23
    return arr*gain
end

"""
    model_with_dataset(::Val{:autoregress}, dataset, batchsize; rng=Random.GLOBAL_RNG)

Build a minibatched variational-inference target for `autoregress_sparse` (AR
order `p = 30`), fitted to the ECG series from `time_series_dataset(Val(:ecg))`.
Full-data ELBO evaluation streams the data to the GPU through `CUDA.CuIterator`.

Returns `(prob, b⁻¹, sample_train_batch!, compute_full_elbo!, validate)`:
- `prob`: `DynamicPPL.LogDensityFunction` over the model, evaluated in a
  `MiniBatchContext`.
- `b⁻¹`: inverse of the model's bijector. It is applied to draws from `q` before
  they are unflattened in `validate`.
- `sample_train_batch!(prob)`: returns `prob` rebound to the next minibatch, with
  the likelihood scaled by `n_data / length(batch)`.
- `compute_full_elbo!(q, b⁻¹, M)`: `M`-sample Monte Carlo ELBO over all `n_data`
  points.
- `validate(q, b⁻¹)`: RMSE between `x[valid_idx]` and the one-step predictions,
  averaged over 100 draws from `q`.

NOTE(review): the `dataset` argument is never used; the ECG series is hard-coded.
NOTE(review): `valid_idx` (used in `validate`) is not defined in this function.
Unless a global of that name exists elsewhere, `validate` throws `UndefVarError`.
"""
function model_with_dataset(::Val{:autoregress}, dataset, batchsize; rng=Random.GLOBAL_RNG)
    x         = time_series_dataset(Val(:ecg))
    p         = 30
    # The first `p` points have no full lag window, so only p+1:end are targets.
    n_data    = length(x) - p
    batchsize = min(batchsize, n_data)

    # Rows per chunk when streaming the full dataset to the GPU for the ELBO.
    gpu_batchsize = 1_000

    data_idx  = p+1:length(x)
    data_itr  = Iterators.partition(data_idx, batchsize)
    # `first` does not advance `data_itr`, so the first `sample_train_batch!`
    # call returns this same batch again.
    batch_idx = first(data_itr)

    # Lazily yields Float32 `(X_auto, x)` chunks, where row n of X_auto is
    # x[n-p:n-1]. CuIterator moves each chunk to the device as it is consumed.
    full_batch_itr = Iterators.Generator(
          Iterators.partition(data_idx, gpu_batchsize)) do _batch_idx
        X_auto_batch  = x[_batch_idx .+ (-p:-1)'] |> Array{Float32}
        x_batch       = x[_batch_idx]             |> Array{Float32}
        (X_auto_batch, x_batch)
    end |> CUDA.CuIterator

    context = DynamicPPL.MiniBatchContext(; batch_size=batchsize, npoints=n_data)
    # `NaN` fills the placeholder `z` argument, which the model overwrites.
    model   = autoregress_sparse(x[batch_idx], x[batch_idx .+ (-p:-1)'], p, NaN)

    varinfo = DynamicPPL.VarInfo(model)
    b       = Bijectors.bijector(model)
    b⁻¹     = inverse(b)
    prob    = DynamicPPL.LogDensityFunction(model, varinfo, context)

    # Same model/varinfo, evaluated under PriorContext (prior terms only).
    prior_ctxt = DynamicPPL.PriorContext()
    prior_prob = DynamicPPL.LogDensityFunction(model, varinfo, prior_ctxt)

    # Advance the minibatch stream and rebind the model data.
    # The `!` refers to the captured `data_itr`/`batch_idx`. They are reassigned
    # here, so iteration state persists across calls. The `prob` argument itself
    # is not mutated: `@set` returns an updated copy.
    # When the stream is exhausted, it restarts on a reshuffled partition. The
    # first epoch therefore runs in sequential (unshuffled) order.
    function sample_train_batch!(prob)
        if isempty(data_itr)
            data_itr = Iterators.partition(shuffle(rng, data_idx), batchsize)
        end
        batch_idx, data_itr = Iterators.peel(data_itr)

        #prob = @set prob.model.args.batch_idx   = batch_idx
        prob = @set prob.model.args.x      = x[batch_idx]
        prob = @set prob.model.args.X_auto = x[batch_idx .+ (-p:-1)']
        # Rescale the minibatch log-likelihood to the full-data size.
        prob = @set prob.context.loglike_scalar = n_data/length(batch_idx)
        prob
    end

    # Full-data ELBO estimate:
    #   E[log lik] + E[log prior] + H[q] + E[log |det J|],
    # with expectations taken over `M` draws from `q`. `rand_and_logjac` and
    # `unflatten_and_stack` are defined elsewhere. Presumably `zs` has one
    # column per draw and `∑logdetjac` is summed over the draws.
    function compute_full_elbo!(q, b⁻¹, M)
        zs, ∑logdetjac = rand_and_logjac(rng, q, b⁻¹, M)
        𝔼ℓprior = mapreduce(+, 1:M) do m
            LogDensityProblems.logdensity(prior_prob, view(zs, :, m))
        end / M

        # θ_dev must have p rows (it multiplies X_auto chunks). σ⁻¹_dev and ν_dev
        # are presumably 1×M row vectors, one entry per draw. TODO confirm.
        samples = unflatten_and_stack(zs, prob, [:θ, :σ⁻¹, :d])
        θ_dev   = samples[:θ]      |> Array{Float32} |> CuArray
        σ⁻¹_dev = samples[:σ⁻¹]    |> Array{Float32} |> CuArray
        ν_dev   = samples[:d]      |> Array{Float32} |> CuArray
        ν       = samples[:d][1,:]

        # Student-t log normalizer, log Γ((ν+1)/2) - ½ log(νπ) - log Γ(ν/2),
        # for each draw, times the n_data observations.
        # NOTE(review): uses `d` directly, whereas the model uses `d + eps`. It also
        # omits the `n_data * log(σ⁻¹)` Jacobian, mirroring the model.
        νp12 = @. (ν + 1) / 2
        ∑ₘℓZ = n_data*sum(
            @. StatsFuns.loggamma(νp12) - (logπ + log(ν))/2 - StatsFuns.loggamma(ν/2))

        # Kernel term -((ν+1)/2) * log(1 + z²/ν), summed over every data point
        # and every draw, chunk by chunk on the GPU.
        νp12_dev       = @. (ν_dev + 1) / 2
        #prog           = Progress(length(full_batch_itr))
        ∑ₘℓlike_unnorm = mapreduce(+, full_batch_itr) do (X_auto_batch_dev, x_batch_dev)
            z_batch_dev = (X_auto_batch_dev*θ_dev .- reshape(x_batch_dev, (:,1))) .* σ⁻¹_dev
            #next!(prog)
            -sum(log1p.((z_batch_dev.^2)./ν_dev).*νp12_dev)
        end
        # Presumably forces collection of the per-chunk device temporaries.
        GC.gc()

        𝔼ℓlike = ∑ₘℓlike_unnorm/M + ∑ₘℓZ/M
        𝔼ℓlike + 𝔼ℓprior + entropy(q) + ∑logdetjac/M
    end

    # RMSE of one-step predictions dot(x[n-p:n-1], θ), averaged over 100 draws
    # from `q`, against x at `valid_idx`.
    # NOTE(review): `valid_idx` is not defined in this function (see docstring).
    function validate(q, b⁻¹)
        n_samples = 100
        y_pred    = mapreduce(+, 1:n_samples) do _
            z        = b⁻¹(rand(rng, q))
            vi_new   = DynamicPPL.unflatten(prob.varinfo, prob.context, z)
            metadata = vi_new.metadata
            θ        = metadata[Symbol("θ")].vals
            map(valid_idx) do n
                @inbounds μₙ = dot(view(x, n-p:n-1), θ)
            end
        end / n_samples
        sqrt(mean((x[valid_idx] - y_pred).^2))
    end

    prob, b⁻¹, sample_train_batch!, compute_full_elbo!, validate
end
