using Functional
using Distributions
using StatsBase
using Parameters
using SparseArrays
using DataFrames
using DataFramesMeta
using Arpack
using Transducers
using LinearAlgebra

# Value function, stationary distribution, etc.
# ---------------------------------------------------------------------- #
"""
    transition_mtx(mdp::BirthDeathMDP, a::Int64)

Build the sparse `mdp.N × mdp.N` one-step transition matrix under action `a`.

For every state `s`, `increment_probs(mdp, s, a)` supplies three probabilities that
are placed on the entries for `s - 1`, `s` and `s + 1`. Target states outside
`1:mdp.N` are clamped to the nearest boundary state.
"""
function transition_mtx(mdp::BirthDeathMDP, a::Int64)
    nstates = mdp.N
    T = spzeros(nstates, nstates)
    for state in 1:nstates
        down, stay, up = increment_probs(mdp, state, a)
        below = max(state - 1, 1)
        above = min(state + 1, nstates)
        # Accumulate instead of assigning: at the edges `below`/`above` coincide
        # with `state`, so several probabilities land in the same cell.
        T[state, below] += down
        T[state, state] += stay
        T[state, above] += up
    end
    return T
end

"""
    wls(X, y, w; λ=0.)

Weighted (optionally ridge-regularised) least squares.

Solves the normal equations `(X' * W * X + λ * I) β = X' * W * y`, where `W` is the
diagonal matrix built from `w`, and returns `β`.

# Arguments
- `X`: design matrix, one row per observation.
- `y`: responses, one entry (or row) per row of `X`.
- `w`: observation weights, one per row of `X`.

# Keywords
- `λ`: ridge penalty added to the diagonal of the weighted Gram matrix.
"""
function wls(X, y, w; λ=0.)
    # X' * W without materialising the diagonal weight matrix.
    Xtw = X' .* w'
    # `\` and `*` share precedence and associate to the left, so the right-hand side
    # has to be parenthesised; otherwise a full p×n system is solved before the
    # product with `y` is taken.
    return (Xtw * X + λ * I) \ (Xtw * y)
end

"""
    summarize_mdp(mdp; p_treat=0.5)

Summarise `mdp` under the policy that picks the treatment action with probability
`p_treat` in every state.

The policy's transition matrix is the `p_treat`-weighted mixture of the per-action
matrices from `transition_mtx`; it is then passed to the matrix method of
`summarize_mdp`.

# Keywords
- `p_treat`: probability of choosing action 2. Action 1 is control and action 2 is
  treatment, matching how `update_stats!` files transitions by action.
"""
function summarize_mdp(mdp; p_treat=0.5)
    P_control = transition_mtx(mdp, 1)
    P_treatment = transition_mtx(mdp, 2)
    # Weight treatment (action 2) by p_treat; the weights used to be attached to the
    # opposite actions, which only coincides with this at p_treat == 0.5.
    P = (1 - p_treat) * P_control + p_treat * P_treatment
    return summarize_mdp(P)
end

"""
    stationary_distribution(P::SparseMatrixCSC{Float64, Int64})

Return a normalised eigenvector of `P'` as a real vector.

The first eigenvector returned by `eigs` is divided by the sum of its entries and its
real part is kept. For a stochastic `P` this is presumably the stationary
distribution (eigenvalue one) — this relies on the ordering `eigs` returns.
"""
function stationary_distribution(P::SparseMatrixCSC{Float64, Int64})
    _, modes = eigs(P'; nev=2, ncv=40, maxiter=1000)
    leading = modes[:, 1]
    # Normalise before dropping the imaginary part, so any complex phase cancels.
    return real.(leading ./ sum(leading))
end

"""
    summarize_mdp(P::AbstractMatrix)

Summarise the Markov chain with transition matrix `P`.

Returns a named tuple with
- `P`: the input matrix,
- `ρ`: the result of `stationary_distribution(P)`,
- `r`: the result of `rewards(P)`,
- `V`, `η`: the solution of `(I - P) * V + η * ones(N) = r` as produced by the
  backslash solve of the `N × (N + 1)` system.
"""
function summarize_mdp(P::AbstractMatrix)
    nstates = size(P, 1)
    stationary = stationary_distribution(P)
    reward = rewards(P)
    # Unknowns are stacked as [V; η]; the extra column of ones carries the scalar η.
    system = hcat(I - P, ones(nstates, 1))
    solution = system \ reward
    return (P=P, ρ=stationary, r=reward,
            V=solution[1:nstates], η=solution[nstates + 1])
end

# Summary functions
# ---------------------------------------------------------------------- #
    # NOTE(review): these read as field declarations of the `PathStats` struct, but no
    # `struct PathStats` header, surrounding fields, or closing `end` is visible here —
    # confirm the definition is intact. The positional constructor `PathStats(N)`
    # passes ten arguments, while only these six fields appear.
    td0::TDState  # Running off-policy estimate of V_π0
    td1::TDState  # Running off-policy estimate of V_π1
    tdπ::TDState  # Running on-policy estimate of V_π{1/2}
    # Presumably Polyak (running) averages of the three estimates above, as produced
    # by `polyak_average!` — TODO confirm.
    polyak_td0::TDState
    polyak_td1::TDState
    polyak_tdπ::TDState

"""
    PathStats(N)

Create empty path statistics for `N` states: a zero leading counter, six independent
zero-initialised `TDState`s, an `N × 2` zero matrix and two empty `N × N` sparse
matrices, passed positionally in that order.
"""
function PathStats(N)
    # Every slot gets its own freshly allocated vector; nothing is shared.
    blank_td() = TDState(zeros(N), 0.)
    return PathStats(
        0,
        blank_td(), blank_td(), blank_td(),
        blank_td(), blank_td(), blank_td(),
        zeros(Float64, (N, 2)),
        spzeros(N, N),
        spzeros(N, N))
end

"""
    polyak_average!(t, td, tdnew)

Return a new `TDState` holding the running average of `td` (weight `t / (t + 1)`) and
`tdnew` (weight `1 / (t + 1)`), for both the `V` vector and the scalar `η`.

Despite the `!` in the name, neither `td` nor `tdnew` is modified: the averaged
vector is freshly allocated and the result must be taken from the return value.
"""
function polyak_average!(t, td, tdnew)
    keep = t / (t + 1)
    blend = 1 / (t + 1)
    V = td.V * keep + blend * tdnew.V
    ηnew = keep * td.η + blend * tdnew.η
    return TDState(V, ηnew)
end

"""
    update_stats!(stats::PathStats, sarsa::SAR)

Fold one observed transition into `stats` and return the updated `PathStats`.

The reward `sarsa.r` is added to `sum_rewards[s, a]`, and the transition
`s -> snew` is counted in `control_transitions` when `a == 1` or in
`treatment_transitions` when `a == 2`. These arrays are modified in place and shared
with the returned object, whose step counter is `t + 1`.

The TD and Polyak-averaged states are carried over unchanged: the TD update calls
are currently disabled (see the commented-out calls after this function), but the
states still have to be forwarded so that the constructor receives the same
positional layout as `PathStats(N)` and later reads of `td0`, `td1`, `tdπ` and the
`polyak_*` fields keep working.
"""
function update_stats!(stats::PathStats, sarsa::SAR)
    @unpack t, td0, td1, tdπ, sum_rewards, control_transitions, treatment_transitions = stats
    @unpack s, a, r, snew = sarsa
    sum_rewards[s, a] += r
    if a == 1
        control_transitions[s, snew] += 1
    elseif a == 2
        treatment_transitions[s, snew] += 1
    end
    return PathStats(t + 1,
                     td0, td1, tdπ,
                     stats.polyak_td0, stats.polyak_td1, stats.polyak_tdπ,
                     sum_rewards,
                     control_transitions,
                     treatment_transitions)
end

# update!([1., 0.], td0),
# update!([0., 1.], td1),
# update!([0.5, 0.5], tdπ),
# polyak_average!(t, stats.polyak_td0, td0),
# polyak_average!(t, stats.polyak_td1, td1),
# polyak_average!(t, stats.polyak_tdπ, tdπ),

"""
    visited_states(counts; thresh=0)

Return a Boolean mask over states that is `true` where both the row sum (outgoing
count) and the column sum (incoming count) of `counts` strictly exceed `thresh`.
"""
function visited_states(counts; thresh=0)
    row_totals = vec(sum(counts; dims=2))
    col_totals = vec(sum(counts; dims=1))
    has_inflow = col_totals .> thresh
    has_outflow = row_totals .> thresh
    return has_inflow .& has_outflow
end

"""
    dq_td(stats::PathStats, V::Vector{Float64})

Treatment-minus-control difference of one-step look-ahead values, averaged over
the empirical state frequencies.

For each arm, the estimated reward plus the estimated transition matrix applied to
`V` is computed on the states flagged by `visited_states` for the pooled counts.
The difference is weighted by each state's share of all pooled outgoing
transitions (the shares are not renormalised over the visited states).
"""
function dq_td(stats::PathStats, V::Vector{Float64})
    rhat = estimate_rewards(stats)
    pooled = stats.control_transitions + stats.treatment_transitions
    keep = visited_states(pooled)
    Vkeep = V[keep]

    # Estimated reward of `arm` plus expected next-state value under that arm.
    function arm_value(transitions, arm)
        Phat = estimate_transition_matrix(transitions[keep, keep]; fill=true)
        return rhat[keep, arm] + Phat * Vkeep
    end
    Qco = arm_value(stats.control_transitions, 1)
    Qtr = arm_value(stats.treatment_transitions, 2)

    visits = vec(sum(pooled; dims=2))
    weights = visits ./ sum(visits)

    return weights[keep]' * (Qtr - Qco)
end

"""
    summarize_stats(stats::PathStats)

Collect the estimators derived from `stats` into a `Dict` keyed by `Symbol`:
the step counter, the two `dq_td` estimates (plain and Polyak-averaged `V`), the
naive difference in mean rewards, the differences of the `η` estimates, and the
individual `η` values of `td0`, `td1` and `tdπ`.
"""
function summarize_stats(stats::PathStats)
    arm_totals = vec(sum(stats.sum_rewards; dims=1))
    arm_counts = [sum(stats.control_transitions), sum(stats.treatment_transitions)]
    # Average reward per observed transition, by arm (column 1, column 2).
    mean_rewards = arm_totals ./ arm_counts
    η_co = stats.td0.η
    η_tr = stats.td1.η
    return Dict(:t => stats.t,
                :dq_td => dq_td(stats, stats.tdπ.V),
                :dq_td_polyak => dq_td(stats, stats.polyak_tdπ.V),
                :naive => mean_rewards[2] - mean_rewards[1],
                :off_policy => η_tr - η_co,
                :reward_rate_co => η_co,
                :reward_rate_tr => η_tr,
                :reward_rate_ex => stats.tdπ.η,
                :polyak_off_policy => stats.polyak_td1.η - stats.polyak_td0.η)
end

"""
    estimate_transition_matrix(counts; fill=false, thresh=0)

Estimate a transition matrix from the transition-count matrix `counts`.

Only states flagged by `visited_states(counts; thresh=thresh)` are kept. Each kept
row is divided by that state's total outgoing count (taken over all columns, so a
row can sum to less than one when some targets are dropped). The code assumes there
are no absorbing states.

# Keywords
- `fill`: when `true`, return a matrix of the same size as `counts` with the estimate
  written into the visited rows and columns and zeros elsewhere; otherwise return
  only the visited sub-matrix.
- `thresh`: count threshold forwarded to `visited_states`.
"""
function estimate_transition_matrix(counts; fill=false, thresh=0)
    keep = visited_states(counts; thresh=thresh)
    row_totals = vec(sum(counts; dims=2))
    # Scaling through a sparse diagonal keeps the product sparse.
    inv_totals = spdiagm(1 ./ row_totals[keep])
    Phat = inv_totals * counts[keep, keep]
    fill || return Phat
    nstates = size(row_totals, 1)
    padded = spzeros(nstates, nstates)
    padded[keep, keep] .= Phat
    return padded
end

"""
    model_based_estimator(stats::PathStats)

Estimate each arm's transition matrix from its own transition counts, summarise it
with `summarize_mdp`, and return the treatment `η` minus the control `η`.
"""
function model_based_estimator(stats::PathStats)
    arm_summary(transitions) = summarize_mdp(estimate_transition_matrix(transitions))
    control_results = arm_summary(stats.control_transitions)
    treatment_results = arm_summary(stats.treatment_transitions)
    return treatment_results.η - control_results.η
end

"""
    estimate_rewards(stats::PathStats)

Return the per-state, per-arm mean reward: `stats.sum_rewards` divided elementwise by
the number of outgoing control (column 1) and treatment (column 2) transitions.
Entries that come out as `NaN` (zero reward over zero visits) are set to zero.
"""
function estimate_rewards(stats::PathStats)
    visits = hcat(sum(stats.control_transitions; dims=2),
                  sum(stats.treatment_transitions; dims=2))
    rs = stats.sum_rewards ./ visits
    for i in eachindex(rs)
        if isnan(rs[i])
            rs[i] = 0
        end
    end
    return rs
end

# Least-squares solve for a value vector over the visited states.
#
# Arguments:
#   counts      - state-to-state transition counts (rows: from, columns: to).
#   sum_rewards - total reward collected in each state.
#
# Builds the stacked system R * V = r .- η, where η is the overall mean reward per
# transition, the first row of R holds the empirical state frequencies, and the
# remaining rows are diagm(state_counts) .- counts, all restricted to the states
# flagged by `visited_states`. Returns the least-squares solution from `R \ rhs`.
#
# NOTE(review): the first right-hand-side entry ends up as 0 - η rather than 0, and
# the transition rows are scaled by state_counts while the rewards are per-visit
# means — confirm both are intended.
# NOTE(review): count matrices created with spzeros(N, N) have Float64 entries and
# would not match the Int64 element type required here — confirm against callers.
function lstd(counts::SparseMatrixCSC{Int64}, sum_rewards::Vector{Float64})
    N = size(counts, 1)
    is_visited = visited_states(counts)
    state_counts = vec(sum(counts; dims=2))
    state_probs = state_counts ./ sum(state_counts)
    η = sum(sum_rewards) / sum(state_counts)
    r = vcat(0., (sum_rewards ./ state_counts)[is_visited])
    R = vcat(reshape(state_probs, (1, N))[:, is_visited],
             (diagm(state_counts) .- counts)[is_visited, is_visited])
    # A normal-equations solve used to be evaluated here and thrown away; it only
    # cost time and could raise on a singular R'R, so just the returned solve is kept.
    return R \ (r .- η)
end

# function lstd_estimator(stats::PathStats)
#     counts = stats.control_transitions + stats.treatment_transitions
#     state_probs = vec(sum(counts; dims=2)) ./ stats.t
#     is_visited = visited_states(counts)
#     V = lstd(counts, vec(sum(stats.sum_rewards; dims=2)))
#     rewards = estimate_rewards(stats)

#     Pco = estimate_transition_matrix(
#         stats.control_transitions[is_visited, is_visited]; fill=true)
#     # rco = rewards(Pco)
#     Qco = rewards[is_visited, 1] + Pco * V

#     Ptr = estimate_transition_matrix(
#         stats.treatment_transitions[is_visited, is_visited]; fill=true)
#     # rtr = rewards(Ptr)
#     Qtr = rewards[is_visited, 2] + Ptr * V

#     state_probs' * (Qtr .- Qco)
# end


"""
    lstd_estimator(stats::PathStats; thresh=0)

Treatment-minus-control difference of one-step look-ahead values, using the value
vector and stationary distribution of the pooled empirical chain.

The pooled (control + treatment) counts give a transition-matrix estimate that is
summarised with `summarize_mdp`. For each arm, the estimated reward plus that arm's
estimated transition matrix applied to the pooled `V` is computed on the visited
states, and the difference is weighted by the pooled `ρ`.

# Keywords
- `thresh`: count threshold forwarded to `visited_states` and
  `estimate_transition_matrix`.
"""
function lstd_estimator(stats::PathStats; thresh=0)
    pooled = stats.control_transitions + stats.treatment_transitions
    keep = visited_states(pooled; thresh=thresh)
    summ = summarize_mdp(estimate_transition_matrix(pooled; thresh=thresh))
    rhat = estimate_rewards(stats)

    # Estimated reward of `arm` plus expected pooled value of the next state.
    function arm_value(transitions, arm)
        Phat = estimate_transition_matrix(
            transitions[keep, keep]; fill=true, thresh=thresh)
        return rhat[keep, arm] + Phat * summ.V
    end
    Qco = arm_value(stats.control_transitions, 1)
    Qtr = arm_value(stats.treatment_transitions, 2)

    return summ.ρ' * (Qtr .- Qco)
end
