import Arpack
import Distributions
import LinearAlgebra: Diagonal, Symmetric, qr, eigen
import StatsBase

"""
  covariance_geom_decay(d::Int, r::Int, gap::Float64, stable_rank::Float64) -> (evecs, evals)

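Generate the eigendecomposition of a covariance matrix for a Gaussian
distribution. The first `r` eigenvalues are equal to 1, the `(r + 1)`-th one
equals `1 - gap`, and the remaining ones decay geometrically at rate
`α = 1 - (1 - gap) / (stable_rank - r)`. The eigenvectors are the columns of a
random orthogonal matrix, obtained from the QR factorization of a `d × d`
standard Gaussian matrix.

Returns the `d × d` matrix of eigenvectors and the eigenvalues, the latter
wrapped in a `d × d` `Diagonal` matrix.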
"""
function covariance_geom_decay(
  d::Int,
  r::Int,
  gap::Float64,
  stable_rank::Float64,
)
  # Explicit throw instead of `@assert`: assertions are not guaranteed to stay
  # enabled, and this is input validation. The exception type and message are
  # kept so that existing callers observe the same error.
  stable_rank > r || throw(
    AssertionError("Stable rank must be higher than numerical rank."),
  )
  # With r > d the spectrum below would contain r entries while the
  # eigenvector matrix is d × d, i.e. the two outputs would be inconsistent.
  0 <= r <= d || throw(
    ArgumentError("Numerical rank r = $r must satisfy 0 <= r <= d = $d."),
  )
  # Decay rate of the tail. With this choice r + (1 - gap) / (1 - α) equals
  # `stable_rank`, i.e. the sum of all eigenvalues (relative to the largest
  # one, which is 1) approaches `stable_rank` as d grows.
  # NOTE: the tail is only nonnegative when 0 <= gap <= 1 and
  # stable_rank - r >= 1 - gap (so that α >= 0); this is not enforced here.
  α = 1 - ((1 - gap) / (stable_rank - r))
  # Tail eigenvalues: (1 - gap), (1 - gap) * α, (1 - gap) * α^2, ...
  # The range is empty when r == d.
  tail = (1 - gap) .* (α .^ (0:(d - r - 1)))
  evals = [ones(Float64, r); tail]
  # Orthogonal factor of the QR factorization of a Gaussian matrix.
  evecs = Matrix(qr(randn(d, d)).Q)
  return evecs, Diagonal(evals)
end


"""
  generate_samples(D::Distributions.AbstractMvNormal, m::Int, n::Int, r::Int)

Draw local samples from a distribution `D` and compute their leading
eigenvectors. The function draws `n` samples for each of `m` machines and
computes the `r` principal eigenvectors of every local empirical covariance
matrix.

Returns a `d × m × r` tensor `evlocal`, where slice `evlocal[:, i, :]` contains
the `d × r` eigenvector matrix of machine `i`. Its columns are ordered by
ascending eigenvalue, so the last column is the leading eigenvector.
"""
function generate_samples(
  D::Distributions.AbstractMvNormal,
  m::Int,
  n::Int,
  r::Int,
)
  d = length(D)
  # Fail early with a clear message; otherwise `eigen` rejects the index range
  # below only after all samples have been drawn.
  1 <= r <= d || throw(
    ArgumentError("Number of eigenvectors r = $r must satisfy 1 <= r <= d = $d."),
  )
  # Draw all m * n samples in a single call; slice i along the third axis
  # holds the d × n sample matrix of machine i.
  samples = reshape(rand(D, m * n), d, n, m)
  evlocal = zeros(d, m, r)
  # Process one machine at a time so that only a single d × d covariance
  # matrix is alive at any point (instead of a d × d × m tensor). This also
  # makes m == 0 return an empty d × 0 × r tensor.
  for i in 1:m
    X = samples[:, :, i]
    # Local empirical covariance: `mean=nothing` lets StatsBase center the
    # data at the empirical mean; the scatter matrix is then scaled by 1 / n.
    covmat = (1 / n) * StatsBase.scattermat(X, dims=2, mean=nothing)
    # Eigenvalues of a `Symmetric` matrix are sorted in ascending order, so
    # the index range (d - r + 1):d selects the r largest ones. The columns
    # of `eiglocal.vectors` follow the same (ascending) order.
    eiglocal = eigen(Symmetric(covmat), (d - r + 1):d)
    evlocal[:, i, :] = eiglocal.vectors
  end
  return evlocal
end

"""
  contaminate_samples!(Ws::Array{Float64, 3}, α::Float64)

Contaminate an `α` fraction of the samples in `Ws`, in place. Every
contaminated slice `Ws[:, i, :]` is overwritten with the same random matrix
with orthonormal columns. The function always contaminates the first
`⌊α * m⌋` slices, where `m = size(Ws, 2)`, and returns `Ws`.
"""
function contaminate_samples!(Ws::Array{Float64, 3}, α::Float64)
  d, m, r = size(Ws)
  # `α` is a fraction of the m slices. Values above 1 would index past the
  # second axis of `Ws`, so reject anything outside [0, 1] (including NaN).
  0 <= α <= 1 || throw(
    ArgumentError("Contamination fraction α = $α must lie in [0, 1]."),
  )
  # Number of slices to overwrite: ⌊α * m⌋.
  n_corr = floor(Int, α * m)
  # A single d × r matrix with orthonormal columns (thin Q factor of a
  # Gaussian matrix); it is shared by all contaminated slices.
  V_corr = Matrix(qr(randn(d, r)).Q)
  # Bounds checks are left enabled on purpose: the loop is cheap compared to
  # the QR factorization above.
  for i in 1:n_corr
    Ws[:, i, :] = V_corr
  end
  return Ws
end
