using LinearAlgebra
using BenchmarkTools
using Base.Threads

"""
    g(P)

Return the reverse-direction 2-D cumulative sum of the matrix `P`: entry `(i, j)` of the
result is the sum of `P[i:end, j:end]`.

Implemented by flipping `P` along both dimensions, applying the forward cumulative sum `f`,
and flipping the result back.
"""
function g(P)
    flipped = reverse(reverse(P; dims=1); dims=2)
    accumulated = f(flipped)
    return accumulated[end:-1:begin, end:-1:begin]
end
"""
    f(P)

Return the 2-D cumulative sum of `P`: a cumulative sum along dimension 1 followed by a
cumulative sum along dimension 2, so entry `(i, j)` is the sum of `P[1:i, 1:j]`.
"""
function f(P)
    along_rows = cumsum(P; dims=1)
    return cumsum(along_rows; dims=2)
end
# One summand `p * log(p / q)` of the divergence, using the convention 0 * log(0 / q) = 0
# so that zero entries of `P` contribute nothing instead of producing NaN.
function _kl_term(p, q)
    ratio = p / q
    return iszero(p) ? zero(ratio) : p * log(ratio)
end

"""
    kl(P, Q)

Return the generalized Kullback-Leibler divergence
`sum(P .* log.(P ./ Q)) - sum(P) + sum(Q)` between the arrays `P` and `Q`.

Entries where `P` is exactly zero contribute `0` to the first sum (the limit of
`p * log(p)` as `p -> 0`), so arrays containing zeros yield a finite value rather than
`NaN`. `P` and `Q` are combined by broadcasting.
"""
kl(P, Q) = sum(_kl_term.(P, Q)) - sum(P) + sum(Q)
"""
    get_n_params(J)

Return the number of parameters for mode sizes `J`:

    1 + sum(J[d] - 1) + sum((J[d] - 1) * (J[d+1] - 1) for consecutive d) + (J[1] - 1) * (J[end] - 1)

i.e. one constant term, the one-body terms, the two-body terms of consecutive modes, and
the two-body term that pairs the first mode with the last one.
"""
function get_n_params(J)
    free = J .- 1                                  # free levels per mode
    n_one_body = sum(free)
    n_consecutive = sum(free[1:end-1] .* free[2:end])
    n_closing = free[1] * free[end]                # pair (first mode, last mode)
    return 1 + n_one_body + n_consecutive + n_closing
end

"""
    update_prob_from_θ!(T, θ_1b, θ_2b, H_1b, H_2b)

Overwrite `T` with the normalized distribution defined by the one-body parameters `θ_1b`
and the two-body parameters `θ_2b`, and return `T`.

# Arguments
- `T`: array that receives the result; its previous contents are ignored.
- `θ_1b`: one vector per mode. The first entry of every vector is overwritten with `0.0`.
- `θ_2b`: one matrix per mode pair; entry `d < D` is indexed by modes `(d, d+1)` and the
  last entry by modes `(1, D)`.
- `H_1b`, `H_2b`: containers of length `D`; their entries are replaced by the cumulative
  sums of `θ_1b[d]` and the 2-D cumulative sums `f(θ_2b[d])`.

For every index of `T` the entries of `H_1b` and `H_2b` selected by that index are summed,
exponentiated, and the result is scaled to unit 1-norm. The maximum of the summed values is
subtracted before `exp`; this cancels in the normalization but keeps `exp` from
overflowing to `Inf` (which would turn the normalized result into `NaN`).
"""
function update_prob_from_θ!(T, θ_1b, θ_2b, H_1b, H_2b)
    D = length(θ_1b)

    for d in 1:D
        # The constant offset is fixed to zero; normalization absorbs it.
        θ_1b[d][1] = 0.0
        H_1b[d] = cumsum(θ_1b[d])
        H_2b[d] = f(θ_2b[d])
    end

    for idx in CartesianIndices(T)
        term_1b = 0.0
        term_2b = 0.0
        for d in 1:D
            term_1b += H_1b[d][idx[d]]
            # The last pair closes the cycle and is stored as (mode 1, mode D).
            term_2b += d == D ? H_2b[d][idx[1], idx[D]] : H_2b[d][idx[d], idx[d + 1]]
        end
        T[idx] = term_1b + term_2b
    end

    # Shift by the maximum so the largest exponent is exactly zero.
    if !isempty(T)
        T .= exp.(T .- maximum(T))
    end
    normalize!(T, 1)
    return T
end

"""
    split_1d_2d(θ, J)

Allocate one-body and two-body parameter containers for mode sizes `J`, fill them from the
flat parameter vector `θ` via `split_1d_2d!`, and return `(θ_1b, θ_2b)`.

`θ_1b[d]` has length `J[d]`. `θ_2b[d]` is a zero-initialised `J[d] × J[d+1]` matrix, except
for the last entry, which is `J[1] × J[end]`.
"""
function split_1d_2d(θ, J)
    D = length(J)
    # Shape of the d-th two-body matrix; the last one pairs mode 1 with mode D.
    pair_shape(d) = d == D ? (J[1], J[D]) : (J[d], J[d + 1])

    θ_1b = Vector{Float64}[Vector{Float64}(undef, J[d]) for d in 1:D]
    θ_2b = Matrix{Float64}[zeros(pair_shape(d)...) for d in 1:D]
    return split_1d_2d!(θ_1b, θ_2b, θ, J)
end

"""
    split_1d_2d!(θ_1b, θ_2b, θ, J)

Scatter the flat parameter vector `θ` into the preallocated one-body vectors `θ_1b` and
two-body matrices `θ_2b` for mode sizes `J`, and return `(θ_1b, θ_2b)`.

Layout of `θ` as read by this function:
- `θ[1]` is copied into the first entry of every `θ_1b[d]`;
- the next `sum(J) - D` entries are the one-body parameters, `J[d] - 1` per mode, stored
  in `θ_1b[d][2:end]`;
- the remaining entries are the two-body parameters, one group of
  `(J[a] - 1) * (J[b] - 1)` values per pair `(a, b)`, where the pairs are
  `(d, d+1)` for `d < D` and `(1, D)` for the last one. Each group fills
  `θ_2b[d][2:end, 2:end]` in column-major order; the first row and column are not written.
"""
function split_1d_2d!(θ_1b, θ_2b, θ, J)
    D = length(J)

    # ---- one-body parameters -------------------------------------------------
    offset = 1                      # θ[1] is the shared constant term
    for d in 1:D
        n_free = J[d] - 1
        θ_1b[d][1] = θ[1]
        θ_1b[d][2:end] .= θ[(offset + 1):(offset + n_free)]
        offset += n_free
    end

    # ---- two-body parameters -------------------------------------------------
    offset = sum(J) + 1 - D         # number of constant and one-body parameters
    for d in 1:D
        a, b = d == D ? (1, D) : (d, d + 1)
        n_pair = (J[a] - 1) * (J[b] - 1)
        θ_part = θ[(offset + 1):(offset + n_pair)]

        k = 0
        for j in 2:J[b], i in 2:J[a]
            k += 1
            θ_2b[d][i, j] = θ_part[k]
        end
        offset += n_pair
    end

    return θ_1b, θ_2b
end

"""
    get_M(J)

Return the vector of `CartesianIndex{D}` positions, one per parameter, for mode sizes `J`
(`D = length(J)`). The length of the result is `get_n_params(J)`.

Order of the entries:
1. the all-ones index;
2. for each mode `a = 1:D` and level `b = 2:J[a]`, the index that is `b` in mode `a` and
   `1` elsewhere;
3. for each pair (`(d, d+1)` for `d < D`, then `(1, D)`), the indices that are `r` in the
   first mode of the pair and `q` in the second, with `q = 2:J[second]` as the outer loop
   and `r = 2:J[first]` as the inner loop.

Indices are built directly as `Int` tuples, so mode sizes are not limited to the `Int16`
range and no temporary array is allocated per parameter.
"""
function get_M(J)
    D = length(J)
    n_params = get_n_params(J)
    MCI = Vector{CartesianIndex{D}}(undef, n_params)

    MCI[1] = CartesianIndex(ntuple(_ -> 1, D))
    t = 2

    # One-body positions: a single mode differs from 1.
    for a in 1:D, b in 2:J[a]
        MCI[t] = CartesianIndex(ntuple(k -> k == a ? Int(b) : 1, D))
        t += 1
    end

    # Two-body positions: two modes of a pair differ from 1.
    for d in 1:D
        first_mode, second_mode = d == D ? (1, D) : (d, d + 1)
        for q in 2:J[second_mode], r in 2:J[first_mode]
            MCI[t] = CartesianIndex(
                ntuple(k -> k == second_mode ? Int(q) : (k == first_mode ? Int(r) : 1), D),
            )
            t += 1
        end
    end
    return MCI
end

"""
    update_G!(G, η, M)

Fill the upper triangle of the storage behind `G` (accessed as `G.data`) and return `G`.

For every pair `u <= v` of positions in `M` the entry is
`η[max(M[u], M[v])] - η[M[u]] * η[M[v]]`, where `max` on the two indices is taken
component-wise. Only entries with `u <= v` are written.
"""
function update_G!(G, η, M)
    storage = G.data
    @inbounds for (col, idx_col) in enumerate(M)
        for row in 1:col
            idx_row = M[row]
            joint = max(idx_row, idx_col)
            storage[row, col] = η[joint] - η[idx_row] * η[idx_col]
        end
    end
    return G
end

"""
    get_η(T)

Return a new array of the same size as `T` whose entry at index `(i1, ..., iD)` is the sum
of `T` over all indices that are greater than or equal to it in every dimension. `T` itself
is not modified.

The sums are accumulated in place on a copy of `T`, through a reversed, reshaped view, by
running a cumulative sum along each dimension in turn.
"""
function get_η(T)
    η = copy(T)
    # Viewing the data back-to-front turns forward cumulative sums into reverse ones.
    flipped = reshape(@view(η[end:-1:1]), size(η))
    for axis in 1:ndims(T)
        cumsum!(flipped, flipped; dims=axis)
    end
    return η
end

"""
    get_η_vec!(η_vec, η, M)

Store `η[M[u]]` in `η_vec[u]` for every position `u` of `M` and return `η_vec`.
"""
function get_η_vec!(η_vec, η, M)
    for (u, idx) in enumerate(M)
        η_vec[u] = η[idx]
    end
    return η_vec
end

"""
    get_η_vec(η, M)

Return a new `Vector{Float64}` containing the entries of `η` at the positions listed in `M`
(allocating counterpart of `get_η_vec!`).
"""
function get_η_vec(η, M)
    buffer = Vector{Float64}(undef, length(M))
    return get_η_vec!(buffer, η, M)
end

"""
    b2_decomp(T; newton=true, tmax=10, error_tol=1.0e-5, lr=0.01, verbose=false)

Fit a distribution built from one-body and two-body parameters to the array `T` and return
the fitted array rescaled by `sum(T)`.

The input is normalized to unit 1-norm on a copy; `T` itself is left unmodified. Starting
from the uniform distribution, each iteration compares the η-values of the current estimate
(see `get_η`) with those of the normalized input at the parameter positions `get_M(size(T))`
and updates the parameter vector, then rebuilds the estimate with `update_prob_from_θ!`.

# Keywords
- `newton`: if `true`, the update solves a linear system with the matrix produced by
  `update_G!`; otherwise a plain step of size `lr` along the η-difference is taken.
- `tmax`: maximum number of iterations.
- `error_tol`: iteration stops once the norm of the η-difference falls below this value.
- `lr`: step size used when `newton == false`.
- `verbose`: if `true`, print the iteration number and residual with `@show`.

Iteration also stops as soon as the residual is larger than in the previous iteration.
The residual is computed from the η-values taken before the update of that iteration.

NOTE(review): for a 2-way input the pair list used by `get_M` contains the pair (1, 2)
twice, so duplicated parameters appear; presumably the method targets three or more
modes - confirm.
"""
function b2_decomp(T; newton=true, tmax=10, error_tol=1.0e-5, lr=0.01, verbose=false)
    sum_input = sum(T)
    # Normalize a copy: a function without `!` must not rescale the caller's array.
    P = normalize(T, 1)
    J = size(T)
    D = ndims(T)
    M = get_M(J)
    η_goal_vec = get_η_vec(get_η(P), M)
    n_params = length(η_goal_vec)

    # Initialize parameters
    θ_vec = zeros(n_params)
    η_vec = zeros(n_params)
    θ_1b = Vector{Float64}[zeros(J[d]) for d in 1:D]
    θ_2b = Matrix{Float64}[
        (d == D ? zeros(J[1], J[D]) : zeros(J[d], J[d + 1])) for d in 1:D
    ]

    # Initialize energy containers (filled by update_prob_from_θ!)
    H_1b = Vector{Vector{Float64}}(undef, D)
    H_2b = Vector{Matrix{Float64}}(undef, D)

    # Initialize FIM storage; only the upper triangle is written by update_G!
    G = Symmetric(Matrix{Float64}(undef, n_params, n_params))

    # Initialize the estimate with the uniform distribution
    Tt = ones(Float64, J...) ./ prod(J)

    res_old = 0.0
    for t in 1:tmax
        ηt = get_η(Tt)
        get_η_vec!(η_vec, ηt, M)

        # The first parameter (constant term) is never updated; normalization fixes it.
        Δη = η_vec[2:end] .- η_goal_vec[2:end]
        if newton
            update_G!(G, ηt, M)
            θ_vec[2:end] .-= G[2:end, 2:end] \ Δη
        else
            θ_vec[2:end] .-= lr .* Δη
        end
        split_1d_2d!(θ_1b, θ_2b, θ_vec, J)
        update_prob_from_θ!(Tt, θ_1b, θ_2b, H_1b, H_2b)

        res = norm(η_vec - η_goal_vec)
        # Stop when the residual starts to grow (res_old == 0.0 marks the first pass).
        if res_old > eps() && res > res_old
            break
        end

        if res < error_tol
            break
        end
        if verbose
            @show t, res
        end
        res_old = res
    end

    return Tt .* sum_input
end
