using IrrationalConstants, LogExpFunctions, SpecialFunctions
include("momentum.jl")

#############
# Momentum refreshment
#############
"""
    stream(x, u)

Smoothly varying schedule defined in the momentum-refreshment section of this file.
Returns `(1 + sin(2x + u)) / 2`, which lies in `[0, 1]` for real arguments.
The element type follows `x` (via `one(x)`) promoted with the sine term.
"""
function stream(x, u)
    phase = 2x + u
    return (one(x) + sin(phase)) / 2
end
"""
    constant(x, u)

Constant schedule from the momentum-refreshment section of this file: ignores both
arguments and always returns `π/16` as a `Float64`. It has the same call signature
as `stream`, presumably so the two are interchangeable (confirm at call sites).
"""
function constant(x, u)
    return π / 16
end
# stream_x(x, u) = (sin(2.0*x) + 1.0)/2.0
# stream_u(x, u) = (sin(π/4.0*u) + 1.0)/2.0

###############
# pseudo time shift
###############
# Shift `u` forward by π/16 and reduce the result modulo 1 (wrapping into the unit interval).
function time_shift(u::Real)
    shifted = u + π / 16
    return mod(shifted, one(u))
end

# Undo `time_shift` (up to floating-point rounding): shift back by π/16, then reduce modulo 1.
function inv_timeshift(u::Real)
    shifted = u - π / 16
    return mod(shifted, one(u))
end



####################
# rotation matrix
####################
# function rotation_mat(θ::Real)
# ```
# 2d-rotation matrix for counterclockwise rotation by angle θ
# ```
#     return  [cos(θ) -sin(θ); sin(θ) cos(θ)]
# end

# X = [0:0.1:100 ;]
# plot(X, [stream(x, i) for (x, i) in zip(X, [1:1001 ;])] )
# plot(rand(1001))

###############
# log-sum-exp utilities
##############
"""
    logsumexp_sweep(X, Ns)

Return a vector `L` with `L[i] == log(sum(exp, X[1:Ns[i]]))`, i.e. the log-sum-exp of
the prefix of `X` ending at each cut point in `Ns`. The result is built incrementally:
each new chunk `X[Ns[i-1]+1:Ns[i]]` is reduced on its own and then merged into the
running total, so every element of `X` is visited once. The result is a `Vector{Float64}`.

# Arguments
- `X`: real values in log space.
- `Ns`: non-empty, non-decreasing cut points whose last entry equals `length(X)`.
  Repeated cut points are allowed (they repeat the previous value).

# Throws
- `ArgumentError` if `Ns` is empty, not sorted, or does not end at `length(X)`.
"""
function logsumexp_sweep(X::AbstractVector{T}, Ns::AbstractVector{<:Integer}) where {T<:Real}
    isempty(Ns) && throw(ArgumentError("Ns must be non-empty"))
    # Unsorted cut points would make chunks empty or overlapping and silently give wrong sums.
    issorted(Ns) || throw(ArgumentError("Ns must be sorted in non-decreasing order"))
    last(Ns) == length(X) || throw(ArgumentError(
        "last(Ns) must equal length(X) = $(length(X)), got $(last(Ns))"))

    L = zeros(length(Ns))
    @views begin
        L[1] = LogExpFunctions.logsumexp(X[1:Ns[1]])
        for i in 2:length(Ns)
            chunk = LogExpFunctions.logsumexp(X[Ns[i-1]+1:Ns[i]])
            # logaddexp merges the two partial results without allocating a temporary array
            L[i] = LogExpFunctions.logaddexp(chunk, L[i-1])
        end
    end
    return L
end

"""
    cumlogsumexp(arr)

Cumulative log-sum-exp: return `result` with `result[i] == log(sum(exp, arr[1:i]))`,
computed in a single pass by folding `LogExpFunctions.logaddexp` over `arr`.

The result has the shape of `arr`. For numeric input its element type is
`float(eltype(arr))`, so integer inputs are accepted. An empty `arr` yields an empty
result. Indexing assumes 1-based linear indices (as for `Vector`).
"""
function cumlogsumexp(arr)
    n = length(arr)
    S = eltype(arr)
    # Log-domain sums are generally non-integral, so promote numeric eltypes to float.
    result = similar(arr, S <: Number ? float(S) : S)
    n == 0 && return result
    result[1] = arr[1]
    for i in 2:n
        # logaddexp(a, b) == log(exp(a) + exp(b)); LogExpFunctions has no 2-arg logsumexp
        result[i] = LogExpFunctions.logaddexp(result[i-1], arr[i])
    end
    return result
end



"""
    logmeanexp(X::AbstractVector)

Return `log(mean(exp, X))`, computed as `logsumexp(X) - log(length(X))`.
"""
function logmeanexp(X::AbstractVector{T}) where {T<:Real}
    return LogExpFunctions.logsumexp(X) - log(length(X))
end

"""
    logmeanexp(X::AbstractMatrix; dims = 1)

Compute `log(mean(exp, ...))` along dimension `dims` of `X` and return the results
flattened into a vector (`dims = 1`: one value per column; `dims = 2`: one per row).
"""
function logmeanexp(X::AbstractMatrix{T}; dims = 1) where {T<:Real}
    n = size(X, dims)
    lse = LogExpFunctions.logsumexp(X; dims = dims)
    return vec(lse .- log(n))
end

# function logmeanexp_slice(w; dims = d)
# ```
# logsumexp function works on a specific slice of array 
# ```
#     a = maximum(w, dims = dims)
#     wl = mean(expm1.(w .- a) .+ 1.0, dims = dims)
#     return a .+ log.(wl)
# end


"""
    logsumexp_stream(X)

Compute `log(sum(exp, X))` in a single pass over the iterable `X` without allocating
temporaries. It keeps a running maximum `alpha` and a sum `r` of `exp(x - alpha)`
rescaled to that maximum; the result is `log(r) + alpha`.
Adapted from "http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html".

Edge cases:
- Empty `X`, or `X` whose elements are all `-Inf`: returns `-Inf`.
- `X` containing `+Inf` and no `NaN`: returns `Inf`.
- `X` containing `NaN`: returns `NaN`.

The accumulator `r` is a `Float64`.
"""
function logsumexp_stream(X)
    alpha = -Inf   # largest element seen so far
    r = 0.0        # sum of exp(x - alpha) over the elements seen so far
    for x in X
        if x == -Inf
            # exp(-Inf) == 0 contributes nothing; skipping also avoids exp(-Inf - -Inf) = NaN
            continue
        elseif x == alpha
            # exp(0) == 1; handled explicitly so x == alpha == Inf does not produce exp(NaN)
            r += 1.0
        elseif x < alpha
            r += exp(x - alpha)
        else
            # New maximum (or NaN): rescale the existing sum to the new reference point.
            r = r * exp(alpha - x) + 1.0
            alpha = x
        end
    end
    return log(r) + alpha
end

# n = 10_000
# X = 500.0*randn(n)

# @btime logsumexp($X)
# @btime logsumexp_stream($X)

