import LinearAlgebra
import StatsBase

"""
  dist(U::Matrix{Float64}, V::Matrix{Float64})

Return the subspace distance `‖(I - U U') V‖₂` between the column spans of `U`
and `V`, where `‖·‖₂` is the spectral (operator 2-) norm. Both arguments are
expected to have orthonormal columns.
"""
function dist(U::Matrix{Float64}, V::Matrix{Float64})
  # Component of V lying inside span(U).
  projected = U * (U' * V)
  # Whatever is left is the part of V orthogonal to span(U).
  residual = V - projected
  return LinearAlgebra.opnorm(residual)
end

"""
  procrustes_align(V::Matrix{Float64}, U::Matrix{Float64})

Align `V` to `U` by solving the orthogonal Procrustes problem

  \\min_{Z: Z'Z = I_r} \\| VZ - U \\|_F

for a pair of `d × r` matrices `V` and `U` with orthonormal columns.

If `V'U = A Σ B'` is a singular value decomposition, the minimizer is
`Zopt = A B'`. Returns the aligned matrix `V * Zopt`.
"""
function procrustes_align(V::Matrix{Float64}, U::Matrix{Float64})
  F = LinearAlgebra.svd(V' * U)
  # Orthogonal factor closest to V'U in Frobenius norm.
  rotation = F.U * F.Vt
  return V * rotation
end

"""
  procrustes_fixing(Ws::Array{Float64, 3}; n_iter::Int = 1)

Given a `d × m × r` tensor holding `m` matrices with orthonormal columns of size
`d × r` (sample `i` is `Ws[:, i, :]`), align all of them with the first sample
`Ws[:, 1, :]` by solving a Procrustes problem. Return the aligned samples in
another `d × m × r` tensor.

If `n_iter > 1`, repeat the alignment for `n_iter - 1` further steps. Each
further step uses the entrywise mean of the previously aligned samples as the
reference. Values of `n_iter ≤ 1` perform only the initial alignment.
"""
function procrustes_fixing(Ws::Array{Float64, 3}; n_iter::Int = 1)
  # Build a closure over a fixed reference matrix. The reference is computed
  # once per pass instead of once per slice, and no reassigned variable is
  # captured, so nothing gets boxed.
  align_to(ref) = X -> procrustes_align(X, ref)
  # Slicing over dims (1, 3) hands each `d × r` sample `Ws[:, i, :]` to the
  # closure and stacks the results back into a `d × m × r` tensor.
  mapped = mapslices(align_to(Ws[:, 1, :]), Ws, dims=(1, 3))
  for _ in 2:n_iter
    # Consensus reference: the mean over samples of the current alignment.
    consensus = StatsBase.mean(mapped, dims=2)[:, 1, :]
    mapped = mapslices(align_to(consensus), mapped, dims=(1, 3))
  end
  return mapped
end

"""
  pairwise_dist(Ws::Array{Float64, 3})

Given a `d × m × r` tensor `Ws` holding `m` matrices with orthonormal columns of
size `d × r` (sample `i` is `Ws[:, i, :]`), compute the subspace distance
`dist(W_i, W_j)` between every pair of them. Collect the results in a symmetric
`m × m` matrix whose diagonal is zero.
"""
function pairwise_dist(Ws::Array{Float64, 3})
  num_samples = size(Ws, 2)
  # Extract each `d × r` sample once up front. Re-slicing `Ws[:, j, :]` inside
  # the inner loop would copy O(m²) slices instead of m.
  samples = [Ws[:, i, :] for i in 1:num_samples]
  distances = zeros(num_samples, num_samples)
  for i in 1:num_samples, j in (i + 1):num_samples
    distances[i, j] = dist(samples[i], samples[j])
  end
  # Only the strict upper triangle was filled; mirror it into the lower one.
  return distances + distances'
end

"""
  robust_reference_selection(distances::Matrix{Float64})

Given an `m × m` matrix of pairwise distances between samples, pick a reference
sample robustly:

1. For every row, take its `k`-th smallest entry with `k = m ÷ 2 + 1`. The row's
   own (zero) self-distance is included in the count.
2. Let the radius be the smallest of these values.
3. Return the index of the row with the most entries `≤` that radius. Ties go to
   the lowest index.
"""
function robust_reference_selection(distances::Matrix{Float64})
  m = size(distances, 1)
  k = div(m, 2) + 1
  kth_smallest = [partialsort(row, k) for row in eachrow(distances)]
  radius = minimum(kth_smallest)
  neighbor_counts = [count(≤(radius), row) for row in eachrow(distances)]
  return argmax(neighbor_counts)
end

"""
  procrustes_fixing_robust(Ws::Array{Float64, 3}; n_iter::Int = 1)

Run the Procrustes alignment procedure with robust reference selection, given a
tensor `Ws` of size `d × m × r` holding `m` matrices with orthonormal columns
(sample `i` is `Ws[:, i, :]`).

The reference sample is chosen by `robust_reference_selection` applied to the
pairwise subspace distances `pairwise_dist(Ws)`. Every sample is then aligned to
that reference.

If `n_iter > 1`, repeat the alignment for `n_iter - 1` further steps. Each
further step uses the entrywise mean of the previously aligned samples as the
reference. Values of `n_iter ≤ 1` perform only the initial alignment.

Returns a tensor of size `d × m × r` holding the aligned samples.
"""
function procrustes_fixing_robust(Ws::Array{Float64, 3}; n_iter::Int = 1)
  ref_ind = robust_reference_selection(pairwise_dist(Ws))
  @debug "Choosing ref_ind = $(ref_ind) for robust reference."
  # Build a closure over a fixed reference matrix. The reference is computed
  # once per pass instead of once per slice, and no reassigned variable is
  # captured, so nothing gets boxed.
  align_to(ref) = X -> procrustes_align(X, ref)
  # Slicing over dims (1, 3) hands each `d × r` sample to the closure and
  # stacks the results back into a `d × m × r` tensor.
  mapped = mapslices(align_to(Ws[:, ref_ind, :]), Ws, dims=(1, 3))
  for _ in 2:n_iter
    # Consensus reference: the mean over samples of the current alignment.
    consensus = StatsBase.mean(mapped, dims=2)[:, 1, :]
    mapped = mapslices(align_to(consensus), mapped, dims=(1, 3))
  end
  return mapped
end
