from sklearn.neighbors import NearestNeighbors
import numpy as np

""" KL-Divergence estimation through K-Nearest Neighbours

    This module provides four implementations of the K-NN divergence estimator of
        Qing Wang, Sanjeev R. Kulkarni, and Sergio Verdú.
        "Divergence estimation for multidimensional densities via
        k-nearest-neighbor distances." Information Theory, IEEE Transactions on
        55.5 (2009): 2392-2405.

    Code partially from https://github.com/nhartland/KL-divergence-estimators
"""
import warnings

import numpy as np
from scipy.spatial import KDTree
from sklearn.neighbors import NearestNeighbors

import faiss
class FaissNearestNeighbors:
    """Nearest-neighbour search backed by a ``faiss.IndexFlatL2`` index.

    Offers the ``fit`` / ``kneighbors`` pair used by the estimators in this
    module, with a call shape similar to scikit-learn's ``NearestNeighbors``.
    """

    def __init__(self, n_neighbors=5):
        """Store the default number of neighbours; the index is built by ``fit``."""
        self.index = None  # populated by fit()
        self.k = n_neighbors

    def fit(self, X):
        """Build the index from the rows of ``X`` (shape ``(n_samples, dim)``).

        Returns ``self`` so that construction and fitting can be chained.
        """
        # ascontiguousarray guarantees a C-ordered float32 buffer (plain
        # ``astype`` keeps Fortran order) and avoids a copy when the input
        # already has that layout.
        data = np.ascontiguousarray(X, dtype=np.float32)
        self.index = faiss.IndexFlatL2(data.shape[1])
        self.index.add(data)
        return self

    def kneighbors(self, X, k=None):
        """Return ``(distances, indices)`` of the ``k`` nearest indexed points.

        X: (n_queries, dim) query points
        k: number of neighbours to return; defaults to the ``n_neighbors``
           value given at construction time
        return: tuple ``(distances, indices)``, each of shape (n_queries, k);
                distances are the square root of what the index reports
        """
        if k is None:
            k = self.k
        queries = np.ascontiguousarray(X, dtype=np.float32)
        distances, indices = self.index.search(queries, k=k)
        # The flat L2 index reports squared distances (see the faiss docs),
        # hence the square root.
        return np.sqrt(distances), indices


def skl_estimator(s1, s2, k=1, error=False):
    """ KL-Divergence estimator based on k-nearest-neighbour distances.

        The neighbour search is carried out with FaissNearestNeighbors; the
        function name is kept unchanged for existing callers.

        s1: (N_1,D) Sample drawn from distribution P
        s2: (N_2,D) Sample drawn from distribution Q
        k: Number of neighbours considered (default 1)
        error: if True, raise a ValueError when some point of s1 has its
            k-th neighbour within s1 at distance 0; if False (default) only
            emit a RuntimeWarning and carry on with the computation
        return: estimated D(P|Q)
        raises: ValueError if k < 1, if s1 has no more than k points, if s2
            has fewer than k points, or (only when error=True) on a zero
            k-th neighbour distance within s1
    """
    n, m = len(s1), len(s2)

    # Validate up front: n == 1 would otherwise divide by zero in the final
    # m / (n - 1) term, and the neighbour search cannot return k real
    # neighbours from fewer points than that.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}.")
    if n <= k:
        raise ValueError(
            f"The first dataset must contain more than k={k} points, got {n}.")
    if m < k:
        raise ValueError(
            f"The second dataset must contain at least k={k} points, got {m}.")

    d = float(s1.shape[1])

    s1_neighbourhood = FaissNearestNeighbors(n_neighbors=k + 1).fit(s1)
    s2_neighbourhood = FaissNearestNeighbors(n_neighbors=k).fit(s2)

    # Every point of s1 is itself part of the s1 index, so k + 1 neighbours
    # are requested there and the last column is used.
    s1_distances, _ = s1_neighbourhood.kneighbors(s1, k + 1)
    s2_distances, _ = s2_neighbourhood.kneighbors(s1, k)
    rho = s1_distances[:, -1]
    nu = s2_distances[:, -1]

    if np.any(rho == 0):
        message = (
            f"The distance between an element of the first dataset and its {k}-th NN in the same dataset "
            f"is 0; this causes divergences in the code, and it is due to elements which are repeated "
            f"{k + 1} times in the first dataset. Increasing the value of k usually solves this.")
        if error:
            raise ValueError(message)
        warnings.warn(message, RuntimeWarning)

    D = np.sum(np.log(nu / rho))

    # this second term should be enough for it to be valid for m \neq n
    return (d / n) * D + np.log(m / (n - 1))
