import numpy as np

from sklearn.base import BaseEstimator, TransformerMixin


class TangentSpace(BaseEstimator, TransformerMixin):
    """Map matrices to flattened tangent vectors at a fixed-rank reference.

    Every matrix is factored by ``to_quotient`` into an ``(n, rank)`` factor,
    which ``logarithm_`` aligns against the reference factor; the resulting
    ``(n, rank)`` difference is flattened into one feature row.

    Parameters
    ----------
    rank : int
        Number of leading eigenpairs kept by ``to_quotient``.
    ref : array_like of shape (n, n), optional
        Reference matrix. When ``None``, ``fit`` uses the arithmetic mean of
        the training matrices along the first axis.

    Attributes
    ----------
    reference_ : ndarray of shape (n, n)
        Reference matrix used to build ``Y_ref_``.
    Y_ref_ : ndarray of shape (n, rank)
        ``to_quotient`` factor of ``reference_``.
    """

    def __init__(self, rank, ref=None):
        """Store hyper-parameters unchanged (scikit-learn convention)."""
        self.rank = rank
        self.ref = ref

    def fit(self, X, y=None):
        """Compute the reference factor.

        Parameters
        ----------
        X : array_like of shape (n_matrices, n, n)
            Training matrices; only used when ``ref`` is None.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        self
        """
        ref = self.ref
        if ref is None:
            ref = np.mean(X, axis=0)
        ref = np.asarray(ref)
        Y = to_quotient(ref, self.rank)
        self.reference_ = ref
        self.Y_ref_ = Y
        return self

    def transform(self, X, verbose=False):
        """Project every matrix of ``X`` onto the tangent space at the reference.

        Parameters
        ----------
        X : array_like of shape (n_matrices, n, n)
            Matrices to transform.
        verbose : bool, default False
            When True, print a carriage-return progress counter to stdout.

        Returns
        -------
        ndarray of shape (n_matrices, n * rank)

        Raises
        ------
        ValueError
            If ``X`` is not three-dimensional.
        """
        # Accept the same array-like inputs that ``fit`` accepts.
        X = np.asarray(X)
        if X.ndim != 3:
            raise ValueError(
                'X must be a 3-D array of shape (n_matrices, n, n), '
                'got shape %r' % (X.shape,))
        n_mat, n, _ = X.shape
        output = np.zeros((n_mat, n * self.rank))
        for j, C in enumerate(X):
            if verbose:
                print('\r %d / %d' % (j + 1, n_mat), end='', flush=True)
            Y = to_quotient(C, self.rank)
            output[j] = logarithm_(Y, self.Y_ref_).ravel()
        if verbose and n_mat:
            # Terminate the '\r' progress line so later output starts fresh.
            print()
        return output


def to_quotient(C, rank):
    """Return the rank-``rank`` square-root factor of a symmetric matrix.

    Keeps the ``rank`` largest eigenpairs ``(d_i, u_i)`` of ``C`` and returns
    ``Y = [u_i * sqrt(d_i)]`` so that ``Y @ Y.T`` reconstructs the truncated
    matrix. ``Y`` is defined only up to right-multiplication by an orthogonal
    matrix, which ``logarithm_`` removes by Procrustes alignment.

    Parameters
    ----------
    C : array_like of shape (n, n)
        Symmetric matrix, assumed PSD up to round-off.
    rank : int
        Number of leading eigenpairs to keep, ``1 <= rank <= n``.

    Returns
    -------
    Y : ndarray of shape (n, rank)

    Raises
    ------
    ValueError
        If ``rank`` is outside ``[1, n]``.
    """
    d, U = np.linalg.eigh(C)
    n = d.shape[0]
    # Guard explicitly: ``U[:, -0:]`` would silently select every column.
    if not 1 <= rank <= n:
        raise ValueError('rank must be in [1, %d], got %r' % (n, rank))
    # eigh returns eigenvalues in ascending order, so the leading eigenpairs
    # are the trailing columns.
    U = U[:, n - rank:]
    # Round-off can push near-zero eigenvalues of a PSD matrix slightly below
    # zero; clip them so the square root cannot produce NaN.
    d = np.clip(d[n - rank:], 0, None)
    return U * np.sqrt(d)


def distance2(S1, S2, rank=None):
    """Squared Bures-Wasserstein-style distance between two symmetric matrices.

    Evaluates ``tr(S1) + tr(S2) - 2 * tr((S1^{1/2} S2 S1^{1/2})^{1/2})``.
    Both square roots are taken with ``sqrtm``, truncated to ``rank``
    eigenpairs (all of them when ``rank`` is None). The traces of ``S1`` and
    ``S2`` themselves are never truncated. Inputs are presumably symmetric
    PSD, since ``sqrtm`` relies on ``eigh``.
    """
    root_s1 = sqrtm(S1, rank)
    sandwich = root_s1 @ (S2 @ root_s1)
    cross_term = np.trace(sqrtm(sandwich, rank))
    return np.trace(S1) + np.trace(S2) - 2 * cross_term


def sqrtm(C, rank=None):
    """Return the (optionally truncated) symmetric square root of ``C``.

    ``C`` is eigendecomposed with ``np.linalg.eigh``. Only the ``rank``
    largest eigenpairs are kept (all ``C.shape[0]`` of them when ``rank`` is
    None), and the square root is taken of the absolute eigenvalues.
    """
    keep = C.shape[0] if rank is None else rank
    eigvals, eigvecs = np.linalg.eigh(C)
    # eigh sorts eigenvalues ascending, so the largest ones sit at the end.
    basis = eigvecs[:, -keep:]
    roots = np.sqrt(np.abs(eigvals[-keep:]))
    return basis @ (roots[:, None] * basis.T)


def logarithm_(Y, Y_ref):
    """Return ``Y`` optimally rotated onto ``Y_ref``, minus ``Y_ref``.

    The orthogonal matrix ``Q`` minimising ``||Y @ Q - Y_ref||_F``
    (orthogonal Procrustes) is built from the SVD of ``Y_ref.T @ Y``.
    The returned difference ``Y @ Q - Y_ref`` has the same shape as ``Y``.
    """
    # numpy's svd returns the right factor already transposed (Vh).
    left, _, right_h = np.linalg.svd(Y_ref.T @ Y, full_matrices=False)
    rotation = (left @ right_h).T
    return Y @ rotation - Y_ref
