import numpy as np
from numpy.linalg import inv
from scipy.special import gammaln, digamma
from scipy.linalg import solve_triangular, cho_solve, cho_factor

def cho_inverse(b):
    """Invert a matrix through its Cholesky factorization.

    Args:
        b: Square matrix; ``scipy.linalg.cho_factor`` must be able to
            factor it (i.e. it should be symmetric positive-definite).

    Returns:
        The inverse of ``b``, obtained by solving ``b @ X = I``.
    """
    factorization = cho_factor(b, lower=True)
    identity = np.eye(b.shape[1])
    return cho_solve(factorization, identity)


def fast_sample_multivariate_normal(mean, cov):
    """Draw one sample from a multivariate normal distribution.

    The covariance is factored with Cholesky when possible; if NumPy
    raises ``LinAlgError`` the factor is built from the SVD of ``cov``
    instead. Randomness comes from the global ``np.random`` state.

    Args:
        mean: Mean vector added to the transformed noise.
        cov: Square covariance matrix.

    Returns:
        ``factor @ z + mean`` where ``z`` is a standard normal vector.
    """
    try:
        factor = np.linalg.cholesky(cov)
    except np.linalg.LinAlgError:
        # Cholesky rejected the matrix: fall back to U * sqrt(S) * Vh.
        u, s, vh = np.linalg.svd(cov)
        noise = np.random.standard_normal(len(u))
        return u @ np.diag(np.sqrt(s)) @ vh @ noise + mean

    noise = np.random.standard_normal(len(factor))
    return factor @ noise + mean

def compute_nig_kl(prior_params, post_params, k=1):
    """KL divergence between two normal-inverse-gamma parameter sets.

    The result is the sum of a Gamma KL term on ``(a, b)`` and the
    Gaussian KL term on ``(mu, lambda)`` evaluated with ``1/sigma^2``
    replaced by ``a_n / b_n``. As written, the formula corresponds to
    KL(posterior || prior).

    Args:
        prior_params: ``(mu_0, lambda_0, a_0, b_0)``; ``mu_0`` is reshaped
            to a column vector, ``lambda_0`` is the prior precision matrix.
        post_params: ``(mu_n, lambda_n, a_n, b_n)`` in the same layout.
        k: When different from 1, the sufficient statistics implied by the
            difference between posterior and prior are scaled by ``k`` and
            the posterior is rebuilt before the KL is evaluated.

    Returns:
        The KL divergence as a NumPy scalar.
    """
    mu_0, lambda_0, a_0, b_0 = prior_params
    mu_n, lambda_n, a_n, b_n = post_params
    inv_lambda_n = np.linalg.inv(lambda_n)

    mu_0 = mu_0.reshape(-1, 1)
    mu_n = mu_n.reshape(-1, 1)

    if k != 1:
        # Recover the data statistics from (posterior - prior) ...
        s_xy = lambda_n @ mu_n - lambda_0 @ mu_0
        s_yy = np.squeeze(
            2 * (b_n - b_0) - mu_0.T @ lambda_0 @ mu_0 + mu_n.T @ lambda_n @ mu_n
        )

        # ... and rebuild the posterior as if the data were weighted by k.
        a_n = a_0 + k * (a_n - a_0)
        lambda_n = lambda_0 + k * (lambda_n - lambda_0)
        inv_lambda_n = np.linalg.inv(lambda_n)
        mu_n = inv_lambda_n @ (k * s_xy + lambda_0 @ mu_0)
        b_n = np.squeeze(
            b_0 + .5 * (k * s_yy + mu_0.T @ lambda_0 @ mu_0 - mu_n.T @ lambda_n @ mu_n)
        )

    # Gamma part of the divergence.
    log_rate_term = a_0 * np.log(b_n / b_0)
    log_norm_term = gammaln(a_n) - gammaln(a_0)
    shape_term = (a_n - a_0) * digamma(a_n)
    rate_term = (b_n - b_0) * a_n / b_n
    gamma_kl = log_rate_term - log_norm_term + shape_term - rate_term

    # Gaussian part, using a_n / b_n for the expected precision scale.
    mean_gap = mu_0 - mu_n
    mahalanobis = np.squeeze(mean_gap.T @ lambda_0 @ mean_gap)
    logdet_0 = np.linalg.slogdet(lambda_0)[1]
    logdet_n = np.linalg.slogdet(lambda_n)[1]
    trace_term = np.trace(lambda_0 @ inv_lambda_n)

    expect_normal_kl = (
        a_n / b_n * mahalanobis + trace_term - logdet_0 + logdet_n - len(mu_0)
    ) / 2

    return gamma_kl + expect_normal_kl

def calc_posterior_params(S, N, model_prior_params, printkl=False):
    """Update normal-inverse-gamma parameters from sufficient statistics.

    Args:
        S: Mapping with entries ``'XX'``, ``'Xy'`` and ``'yy'``. These are
            added to / combined with the prior terms below; presumably
            X^T X, X^T y and y^T y -- confirm against the caller.
        N: Count added (halved) to the shape parameter ``a_0``.
        model_prior_params: ``(mu_0, lambda_0, a_0, b_0)``. ``mu_0`` must be
            2-D (a column vector) because the quadratic forms are indexed
            with ``[0, 0]``.
        printkl: When true, print the KL divergence between the updated
            and prior parameters, split into its two components.

    Returns:
        ``(mu_n, lambda_n, inv_lambda_n, a_n, b_n)``.
    """
    mu_0, lambda_0, a_0, b_0 = model_prior_params

    # Posterior precision and its inverse (Cholesky-based).
    lambda_n = S['XX'] + lambda_0
    inv_lambda_n = cho_inverse(lambda_n)

    mu_n = inv_lambda_n.dot(S['Xy'] + lambda_0.dot(mu_0))
    a_n = a_0 + .5 * N

    prior_quad = mu_0.T.dot(lambda_0).dot(mu_0)
    post_quad = mu_n.T.dot(lambda_n).dot(mu_n)
    b_n = b_0 + .5 * (S['yy'] + prior_quad - post_quad)[0, 0]

    if printkl:
        # Normal-inverse-gamma KL, see
        # https://statproofbook.github.io/P/ng-kl.html
        # The (inverse-)gamma part follows
        # https://en.wikipedia.org/wiki/Inverse-gamma_distribution
        gamma_part = (
            a_0 * np.log(b_n/b_0)
            - (gammaln(a_n) - gammaln(a_0))
            + (a_n-a_0) * digamma(a_n)
            - (b_n - b_0) * a_n / b_n
        )

        mean_gap = mu_0 - mu_n
        mahalanobis = mean_gap.T @ lambda_0 @ mean_gap

        logdet_0 = np.linalg.slogdet(lambda_0)[1]
        logdet_n = np.linalg.slogdet(lambda_n)[1]

        # The expectation is taken over 1/sigma^2 = a/b (rather than
        # sigma^2 = b/(a-1)); the original author left this as an open
        # question.
        normal_part = (
            a_n / b_n * mahalanobis[0,0]
            + np.trace(lambda_0 @ inv_lambda_n)
            - logdet_0
            + logdet_n
            - len(mu_0)
        )
        normal_part = normal_part / 2
        print("True KL", round(gamma_part + normal_part, 6), "=InvGamma KL", round(gamma_part, 6), " + expected normal KL", round(normal_part,6))

    # TODO: decide whether to floor b_n again, e.g. b_n = max(b_n, .1)

    return mu_n, lambda_n, inv_lambda_n, a_n, b_n


def product_of_two_multivariate_normals(mean1, cov1, mean2, cov2):
    """Mean and covariance of the product of two Gaussian densities.

    Uses the identities
    ``cov = cov1 (cov1 + cov2)^-1 cov2`` and
    ``mean = cov2 (cov1 + cov2)^-1 mean1 + cov1 (cov1 + cov2)^-1 mean2``
    (https://math.stackexchange.com/questions/157172/product-of-two-multivariate-gaussians-distributions).
    The inverse of the sum is applied through its Cholesky factor when
    available, and through an explicit inverse otherwise.

    Args:
        mean1, cov1: Mean and covariance of the first Gaussian.
        mean2, cov2: Mean and covariance of the second Gaussian.

    Returns:
        ``(combined_mean, combined_covariance)``.
    """
    cov_sum = cov1 + cov2

    try:
        lower = np.linalg.cholesky(cov_sum)
        # Inverse of the triangular factor: cov_sum^-1 = lower_inv.T @ lower_inv
        lower_inv = solve_triangular(lower, np.eye(len(mean1)), lower=True, check_finite=False)

        half_cov1 = lower_inv @ cov1
        half_cov2 = lower_inv @ cov2
        combined_covariance = half_cov1.T @ half_cov2
        combined_mean = half_cov2.T @ (lower_inv @ mean1) + half_cov1.T @ (lower_inv @ mean2)
    except np.linalg.LinAlgError:
        # Cholesky failed (sum is not positive-definite): invert directly.
        sum_inverse = inv(cov_sum)
        combined_covariance = cov1 @ sum_inverse @ cov2
        combined_mean = cov2 @ sum_inverse @ mean1 + cov1 @ sum_inverse @ mean2

    return combined_mean, combined_covariance


def project_suff_stats(S):
    """Return a copy of ``S`` whose joint statistics matrix passes ``isPD``.

    The blocks ``S['XX']``, ``S['Xy']`` and ``S['yy']`` are assembled into
    the matrix ``[[XX, Xy], [Xy.T, yy]]``. If that matrix fails ``isPD`` it
    is replaced by ``nearestPD`` of itself, and the blocks are then read
    back out of it.

    Args:
        S: Mapping with ``'XX'`` (2-D), ``'Xy'`` (column vector) and
            ``'yy'`` (scalar). The input mapping is shallow-copied, not
            modified.

    Returns:
        The copied mapping with ``'XX'``, ``'Xy'``, ``'yy'`` replaced and
        ``'X'`` set to the last column of the new ``'XX'`` (presumably the
        intercept column -- confirm against the caller).
    """
    projected = S.copy()

    top_rows = np.hstack((projected['XX'], projected['Xy']))
    bottom_row = np.hstack((projected['Xy'].T, [[projected['yy']]]))
    joint = np.vstack((top_rows, bottom_row))

    if not isPD(joint):
        joint = nearestPD(joint)

    xx_block = joint[:-1, :-1]
    projected['XX'] = xx_block
    projected['X'] = xx_block[:, -1][:, None]
    # Xy is read from the bottom row of the joint matrix, as a column.
    projected['Xy'] = joint[-1, :-1][:, None]
    projected['yy'] = joint[-1, -1]

    return projected


def nearestPD_iterative(A):
    """Find the nearest positive-definite matrix to the input.

    A Python/NumPy port of John D'Errico's `nearestSPD` MATLAB code [1],
    which credits [2].

    [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd

    [2] N.J. Higham, "Computing a nearest symmetric positive semidefinite
    matrix" (1988): https://doi.org/10.1016/0024-3795(88)90223-6

    Args:
        A: Square 2-D array.

    Returns:
        A symmetric matrix that passes ``isPD``.
    """
    sym_part = (A + A.T) / 2
    _, sing_vals, vh = np.linalg.svd(sym_part)

    # Symmetric polar factor of the symmetrized input.
    polar_factor = vh.T @ (np.diag(sing_vals) @ vh)

    candidate = (sym_part + polar_factor) / 2
    # Re-symmetrize to remove rounding asymmetry.
    candidate = (candidate + candidate.T) / 2

    if isPD(candidate):
        return candidate

    # This differs from [1]: MATLAB's `chol` accepts matrices with an
    # exactly-zero eigenvalue, NumPy's Cholesky does not. Where [1] uses
    # `eps(mineig)` (MATLAB's `eps` is `np.spacing`), the spacing of the
    # norm of A is used here. CAVEAT: this value is much larger than
    # `eps(mineig)` in [1], since `mineig` is usually on the order of 1e-16
    # and `eps(1e-16)` is on the order of 1e-34, whereas the spacing used
    # here is, for small Gaussian random matrices, on the order of 1e-16.
    jitter = np.spacing(np.linalg.norm(A))
    identity = np.eye(A.shape[0])

    attempt = 1
    while not isPD(candidate):
        smallest_eig = np.min(np.real(np.linalg.eigvals(candidate)))
        # Shift the spectrum up; the shift grows quadratically per attempt.
        candidate += identity * (-smallest_eig * attempt**2 + jitter)
        attempt += 1

    return candidate


# https://math.stackexchange.com/questions/2776803/matrix-projection-onto-positive-semi-definite-cone-with-respect-to-the-spectral
def nearestPD(B):
    """Project a matrix onto the positive semi-definite cone.

    The symmetric part of ``B`` is eigen-decomposed, negative eigenvalues
    are clipped to zero, and the matrix is rebuilt. A symmetric eigensolver
    (``eigh``) is used so the eigenvectors are orthonormal and
    ``V diag(w) V.T`` is a valid reconstruction; a general eigensolver does
    not guarantee that for repeated eigenvalues or non-symmetric input.

    Args:
        B: Square 2-D array, or a 4-D array of shape ``(d, d, d, d)`` which
            is treated as a ``(d**2, d**2)`` matrix. The input is not
            modified.

    Returns:
        A symmetric array with the same shape as ``B`` whose eigenvalues
        are non-negative up to rounding.
    """
    B = np.asarray(B)

    flip = False
    if len(B.shape) == 4:
        d = B.shape[0]
        B = B.reshape((d ** 2, d ** 2))
        flip = True

    # Only the symmetric part contributes to the nearest PSD matrix; this
    # also allocates a fresh array, so the caller's data is left untouched.
    sym_part = (B + B.T) / 2

    w, v = np.linalg.eigh(sym_part)
    w = np.clip(w, 0, None)
    projected = (v * w) @ v.T
    # Remove rounding asymmetry introduced by the reconstruction.
    projected = (projected + projected.T) / 2

    # Safety net: if rounding still left a slightly negative eigenvalue,
    # shift the diagonal by ten times its magnitude.
    min_eig = np.min(np.linalg.eigvalsh(projected))
    if min_eig < 0:
        projected -= 10 * min_eig * np.eye(projected.shape[0])

    if flip:
        projected = projected.reshape((d, d, d, d))

    return projected


def symmetrize(B):
    """Mirror the upper triangle of a matrix onto its lower triangle.

    Args:
        B: 2-D array, or a 4-D array of shape ``(d, d, d, d)`` which is
            handled as a ``(d**2, d**2)`` matrix and reshaped back.

    Returns:
        A new array whose upper triangle (including the diagonal) equals
        that of ``B`` and whose lower triangle is its transpose.
    """
    shape_in = B.shape
    is_tensor = len(shape_in) == 4

    matrix = B.reshape((shape_in[0] ** 2, shape_in[0] ** 2)) if is_tensor else B

    # Upper triangle with diagonal, plus the strict upper triangle mirrored.
    mirrored = np.triu(matrix) + np.triu(matrix, k=1).T

    if is_tensor:
        return mirrored.reshape((shape_in[0],) * 4)
    return mirrored

def isPD(B, allowance=1e-14):
    """Return True when the input is positive-definite, via Cholesky.

    ``allowance`` is added to the diagonal of a private copy before the
    factorization is attempted, so matrices that are positive
    semi-definite up to rounding are accepted. NumPy's Cholesky reads only
    one triangle of the matrix, so symmetry itself is not verified.

    Args:
        B: Square 2-D array, or a 4-D array of shape ``(d, d, d, d)`` which
            is treated as a ``(d**2, d**2)`` matrix. Never modified.
        allowance: Diagonal jitter added before factorizing.

    Returns:
        True if ``np.linalg.cholesky`` succeeds on the jittered matrix,
        False if it raises ``LinAlgError``.
    """
    B = np.array(B)
    if not np.issubdtype(B.dtype, np.inexact):
        # The float jitter cannot be added in place to integer/bool data.
        B = B.astype(float)

    if len(B.shape) == 4:
        d = B.shape[0]
        B = B.reshape((d ** 2, d ** 2))

    # Stride of (columns + 1) over the flattened matrix walks the diagonal.
    B.flat[::B.shape[1] + 1] += allowance

    try:
        np.linalg.cholesky(B)
    except np.linalg.LinAlgError:
        return False
    return True


def symmetric_flatten_indices(d, lower=True):
    """List row-major flat indices of the strictly off-diagonal triangle.

    Pairs ``(i, j)`` with ``i < j`` are visited with ``i`` as the outer
    loop. Each pair maps to entry ``(j, i)`` (lower triangle) or ``(i, j)``
    (upper triangle) of a ``d x d`` matrix flattened in row-major order.

    Args:
        d: Matrix dimension.
        lower: Select the lower triangle when true, the upper otherwise.

    Returns:
        A list of ``d * (d - 1) / 2`` integer indices.
    """
    return [
        d * j + i if lower else d * i + j
        for i in range(d)
        for j in range(i + 1, d)
    ]

def generate_symmetric_from_triu(d, flattened):
    """Build a symmetric ``d x d`` matrix from its upper-triangle entries.

    Args:
        d: Matrix dimension.
        flattened: Values for the upper triangle (diagonal included), in
            the order produced by ``np.triu_indices(d)``.

    Returns:
        A float array equal to its own transpose.
    """
    upper = np.zeros((d, d))
    rows, cols = np.triu_indices(d)
    upper[rows, cols] = flattened
    # The diagonal is present in both `upper` and `upper.T`; remove one copy.
    return upper + upper.T - np.diag(upper.diagonal())
