import numpy as np
from scipy.special import gamma, digamma, polygamma, gammaln
from util import compute_nig_kl

# import tensorflow_probability as tfp
# tfd = tfp.distributions
#
# class NormalInverseGamma(tfd.JointDistributionSequentialAutoBatched):
#     def __init__(self, loc, cov, conc, scale):
#         self.loc = loc
#         self.cov = cov
#         self.conc = conc
#         self.scale = scale
#
#         super(NormalInverseGamma, self).__init__([
#             tfd.InverseGamma(self.conc, self.scale, name="sigma"),
#             lambda sigma: tfd.MultivariateNormalFullCovariance(loc=self.loc,
#                                                                covariance_matrix = sigma * self.cov)
#             ])
#
#         self._parameters = dict(loc=loc, cov=cov, conc=conc, scale=scale)


def nig_fit(samples, prior_params, tol=1e-6, max_iter=1000):
    """
    Fit a Normal-Inverse-Gamma (NIG) distribution to samples and compute its
    KL divergence with respect to the prior via ``compute_nig_kl``.

    The inverse-gamma part is fitted by maximum likelihood with the
    fixed-point iteration from https://arxiv.org/pdf/1605.01019.pdf
    (see the link for other methods, e.g. MAP).

    Parameters
    ----------
    samples : tuple
        ``(sigmas, thetas)``: ``sigmas`` is a 1-D array of n positive
        variance samples, ``thetas`` is an (n, d) array of parameter samples.
    prior_params : sequence
        Passed unchanged to ``compute_nig_kl``; presumably
        ``(mu_0, lambda_0, a_0, b_0)`` -- confirm against ``util``.
    tol : float, optional
        Absolute change in the shape parameter below which the iteration
        is considered converged.
    max_iter : int, optional
        Upper bound on the number of iterations for the shape parameter.

    Returns
    -------
    kl
        Value returned by ``compute_nig_kl(prior_params, post_params)``.
    (prior_params, post_params) : tuple
        The prior as given and the fitted ``[mu_n, lambda_n, a_n, b_n]``.

    Raises
    ------
    ValueError
        If the moment-based initial shape is not finite (e.g. fewer than two
        sigma samples, or all sigma samples identical).
    RuntimeError
        If the shape iteration leaves the valid domain or does not converge
        within ``max_iter`` iterations.
    """
    sigmas = np.asarray(samples[0])
    thetas = np.asarray(samples[1])
    n = len(sigmas)
    sum_inv_sigma = np.sum(1. / sigmas)

    # Method-of-moments initialisation of the shape parameter.
    mean_sigma = sigmas.mean()
    var_sigma = np.sum((sigmas - mean_sigma) ** 2) / (n - 1)
    a = mean_sigma ** 2 / var_sigma + 2
    if not np.isfinite(a):
        raise ValueError(
            "cannot initialise inverse-gamma shape: need at least two "
            "non-identical, finite sigma samples"
        )

    # Data-dependent constant of the profile log-likelihood derivative,
    # which equals n * (C - digamma(a) + log(n * a)).
    C = -np.log(sum_inv_sigma) - np.mean(np.log(sigmas))

    # Bounded iteration: the original unbounded loop never terminated once
    # the iterate became NaN, because abs(nan - nan) < tol is always False.
    for _ in range(max_iter):
        prev_a = a
        score = C - digamma(a) + np.log(n * a)
        inv_a = 1. / a + score / (a ** 2 * (1. / a - polygamma(1, a)))
        a = 1. / inv_a
        if not np.isfinite(a) or a <= 0:
            raise RuntimeError(
                "inverse-gamma shape iteration left the valid domain"
            )
        if np.abs(a - prev_a) < tol:
            break
    else:
        raise RuntimeError(
            "inverse-gamma shape iteration did not converge in "
            "{} iterations".format(max_iter)
        )

    a_n = a
    # Closed-form MLE of the scale given the shape.
    b_n = n * a_n / sum_inv_sigma

    # NOTE(review): mu_n is the unweighted sample mean; under
    # theta | sigma ~ N(mu, sigma * inv(lambda)) the MLE would weight each
    # sample by 1 / sigma -- confirm that the plain mean is intended.
    mu_n = thetas.mean(0)
    diff = thetas - mu_n

    # Each outer product is scaled by 1 / sigma_i before averaging.
    squared_diff = diff.T @ (diff / sigmas.reshape(-1, 1))
    inv_lambda_n = 1. / n * squared_diff
    lambda_n = np.linalg.inv(inv_lambda_n)

    post_params = [mu_n, lambda_n, a_n, b_n]
    kl = compute_nig_kl(prior_params, post_params)
    print(kl)

    return kl, (prior_params, post_params)

def n_fit(samples, prior_params):
    """
    Fit a Gaussian to samples by maximum likelihood and return the KL
    divergence KL(N(mu_n, cov_n) || N(mu_0, cov_0)) from the fit to the prior.

    Parameters
    ----------
    samples : array-like, shape (n, d)
        One sample per row.
    prior_params : tuple
        ``(mu_0, cov_0)``: prior mean with d entries (1-D, row or column
        vector) and the (d, d) prior covariance matrix.

    Returns
    -------
    kl : scalar
        The KL divergence; it is also printed.
    """
    thetas = np.asarray(samples)
    n = len(thetas)
    mu_n = thetas.mean(0)
    diff = thetas - mu_n

    # MLE covariance (normalised by n, not n - 1) and its inverse.
    squared_diff = diff.T @ diff
    inv_lambda_n = 1. / n * squared_diff
    lambda_n = np.linalg.inv(inv_lambda_n)

    mu_0, cov_0 = prior_params
    # Flatten so that a (d, 1) or (1, d) prior mean does not broadcast
    # against the 1-D mu_n into a (d, d) matrix.
    mu_0 = np.ravel(mu_0)
    lambda_0 = np.linalg.inv(cov_0)

    diff_mu = mu_0 - mu_n
    squared_diff_mu = diff_mu @ lambda_0 @ diff_mu

    _, logdet_0 = np.linalg.slogdet(lambda_0)
    _, logdet_n = np.linalg.slogdet(lambda_n)

    # Closed form of the Gaussian KL written with precision matrices:
    # log|cov_0| - log|cov_n| == -logdet_0 + logdet_n.
    kl = (squared_diff_mu + np.trace(lambda_0 @ inv_lambda_n)
          - logdet_0 + logdet_n - mu_0.size)
    kl = kl / 2

    print(kl)

    return kl
