import numpy as np

from sklearn_extra.robust import RobustWeightedKMeans
from sklearn.cluster import KMeans
from sklearn.cluster import DBSCAN
import pandas as pd

from pathos.multiprocessing import ProcessingPool as Pool
from multiprocessing import cpu_count

import load_synthetic_data as load_data
import ldme
import rme
import eval_performance as eval_perf

### Comparison Algorithms

def robust_kmeans_A0(samples, alpha_, n_clusters=100):
    """Baseline A0: cluster `samples` with RobustWeightedKMeans.

    Parameters
    ----------
    samples : array-like
        Data passed unchanged to the estimator's ``fit``.
    alpha_ : float
        Weight reported for every returned center.
    n_clusters : int, default=100
        Number of clusters requested from the estimator.

    Returns
    -------
    tuple
        ``(cluster_centers_, weights)`` where ``weights`` is a length
        ``n_clusters`` array filled with ``alpha_``.
    """
    estimator = RobustWeightedKMeans(
        n_clusters=n_clusters,
        max_iter=1000,
        weighting='mom',
    )
    # The configured estimator is echoed to stdout before fitting.
    print(estimator)
    estimator.fit(samples)
    weights = np.full(n_clusters, alpha_)
    return estimator.cluster_centers_, weights

def kmean_A0(samples, alpha_, n_clusters=100):
    """Baseline A0: cluster `samples` with plain k-means (single initialisation).

    Parameters
    ----------
    samples : array-like
        Data passed unchanged to ``KMeans.fit``.
    alpha_ : float
        Weight reported for every returned center.
    n_clusters : int, default=100
        Number of clusters requested.

    Returns
    -------
    tuple
        ``(cluster_centers_, weights)`` where ``weights`` is a length
        ``n_clusters`` array filled with ``alpha_``.
    """
    model = KMeans(n_clusters=n_clusters, n_init=1)
    model.fit(samples)
    weights = np.full(n_clusters, alpha_)
    return model.cluster_centers_, weights

def dbscan_A0(samples, eps=1, min_samples=5):
    """Run DBSCAN on `samples` and compute the mean of every found cluster.

    Parameters
    ----------
    samples : numpy.ndarray
        Data to cluster; it is also indexed with a boolean label mask.
    eps : float, default=1
        Forwarded to ``DBSCAN(eps=...)``.
    min_samples : int, default=5
        Forwarded to ``DBSCAN(min_samples=...)``.

    Returns
    -------
    tuple
        ``(components_, labels_, cluster_means)`` where ``cluster_means`` is
        an array with one mean per non-noise label (label ``-1`` is skipped).
    """
    clustering = DBSCAN(eps=eps, min_samples=min_samples)
    clustering.fit(samples)
    labels = clustering.labels_
    # Label -1 marks noise points, which do not form a cluster.
    per_cluster_means = [
        samples[labels == label].mean(axis=0).tolist()
        for label in np.unique(labels)
        if label != -1
    ]
    return clustering.components_, labels, np.asarray(per_cluster_means)

def dbscan_A0_full(samples, alpha_, eps=1, min_samples=5):
    """Baseline A0 wrapper around `dbscan_A0` returning ``(means, weights)``.

    Parameters
    ----------
    samples : numpy.ndarray
        Data to cluster.
    alpha_ : float
        Weight reported for every returned cluster mean.
    eps, min_samples
        Forwarded to `dbscan_A0`.

    Returns
    -------
    tuple
        ``(cluster_means, weights)`` with one ``alpha_`` entry per mean.
    """
    _, _, cluster_means = dbscan_A0(samples, eps=eps, min_samples=min_samples)
    weights = np.full(len(cluster_means), alpha_)
    return cluster_means, weights

### Inner and Outer Stage Algorithms

def gaussian_multifilter(samples, alpha_, parameters):
    """A0 adapter for ``ldme.fast_gaussian_multifilter``.

    Parameters
    ----------
    samples : array-like
        Data forwarded to the multifilter.
    alpha_ : float
        Forwarded to the multifilter and reported as the weight of every
        returned mean.
    parameters : sequence of length 6
        ``(d_factor, r_factor, c_factor, naive_cluster_factor, n_dir_factor,
        k_factor)``; note they are passed on in a different order.

    Returns
    -------
    tuple
        ``(means, weights)`` with one ``alpha_`` entry per mean.
    """
    (d_factor, r_factor, c_factor,
     naive_cluster_factor, n_dir_factor, k_factor) = parameters
    means = ldme.fast_gaussian_multifilter(
        samples, alpha_, r_factor, d_factor, c_factor,
        n_dir_factor, naive_cluster_factor, k_factor,
    )
    weights = np.full(len(means), alpha_)
    return means, weights

def bounded_multifilter(samples, alpha_, parameters):
    """A0 adapter for ``ldme.fast_multifilter``.

    ``delta`` and ``beta`` are derived from the data dimension ``d`` as
    ``1 / d**2`` and ``1 / log(d)`` respectively.

    Parameters
    ----------
    samples : numpy.ndarray
        Two-dimensional data array; its second axis length is used as ``d``.
    alpha_ : float
        Forwarded to the multifilter and reported as the weight of every
        returned mean.
    parameters : sequence of length 8
        ``(d_factor, r_factor, c_factor, naive_cluster_factor, n_dir_factor,
        gamma_factor_inner, delta_dist_factor, runs_factor)``; note they are
        passed on in a different order.

    Returns
    -------
    tuple
        ``(means, weights)`` with one ``alpha_`` entry per mean.
    """
    (d_factor, r_factor, c_factor, naive_cluster_factor, n_dir_factor,
     gamma_factor_inner, delta_dist_factor, runs_factor) = parameters
    _, dim = samples.shape
    delta = 1 / (dim**2)
    beta = 1 / np.log(dim)
    means = ldme.fast_multifilter(
        samples, alpha_, delta, beta, r_factor, n_dir_factor,
        gamma_factor_inner, d_factor, delta_dist_factor, c_factor,
        runs_factor, naive_cluster_factor,
    )
    weights = np.full(len(means), alpha_)
    return means, weights

def first_stage_pruning(samples, all_centers, all_alphas, constant):
    """Prune candidate centers to a separated, sufficiently supported subset.

    Candidates are visited in order of decreasing weight. A candidate is
    skipped when it lies too close to an already accepted center, and it is
    deleted when too few samples fall inside the slabs it forms with the
    accepted centers. Accepting a candidate can in turn invalidate an earlier
    accepted center; that center is then deleted and the whole scan restarts.

    Parameters
    ----------
    samples : numpy.ndarray
        Data array with one sample per row.
    all_centers : numpy.ndarray
        Candidate centers, one per row.
    all_alphas : numpy.ndarray
        Weight associated with each candidate center.
    constant : float
        Scale of the slab half-width ``constant * sqrt(log(1 / alpha))``.

    Returns
    -------
    tuple
        ``(centers, alphas)`` of the accepted candidates, in the sorted
        (decreasing weight) order.
    """
    order = np.argsort(all_alphas)[::-1]
    centers = all_centers[order]
    weights = all_alphas[order]
    n_samples = len(samples)

    accepted = []
    finished = False
    while not finished:
        # Each pass starts from scratch on the (possibly shortened) list.
        accepted = []
        supports = {}
        restart = False
        n_removed = 0
        for position in range(len(weights)):
            if restart:
                break
            # Deleting a rejected candidate shifts all later ones left by one.
            cur = position - n_removed
            center = centers[cur]
            half_width = constant * np.sqrt(np.log(1. / weights[cur]))

            # Skip candidates that are not separated from the accepted ones.
            gaps = np.linalg.norm(centers[accepted] - center, axis=1)
            if np.any(gaps <= 2 * half_width + 4):
                continue

            # Keep the samples lying in every slab around this candidate,
            # one slab per direction towards an accepted center.
            support = samples.copy()
            for other in accepted:
                offset = center - centers[other]
                direction = offset / np.linalg.norm(offset)
                inside = np.abs((support - center) @ direction) <= half_width
                support = support[inside]

            if len(support) < weights[cur] * n_samples / 4:
                centers = np.delete(centers, cur, axis=0)
                weights = np.delete(weights, cur)
                n_removed += 1
                if cur >= len(centers):
                    # The deleted candidate was the last one: nothing left.
                    finished = True
                    break
                continue

            accepted.append(cur)
            supports[cur] = support

            # Re-check earlier accepted centers against the new direction;
            # the half-width used here is the one of the new candidate.
            for other in accepted:
                if other == cur:
                    continue
                offset = center - centers[other]
                direction = offset / np.linalg.norm(offset)
                earlier = supports[other]
                inside = np.abs(np.dot(earlier - centers[other], direction)) <= half_width
                if len(earlier[inside]) < weights[other] * n_samples / 4:
                    centers = np.delete(centers, other, axis=0)
                    weights = np.delete(weights, other)
                    restart = True
                    break

        if not restart:
            finished = True

    return centers[accepted], weights[accepted]

def first_stage_worker_function(A0, alpha_, samples, A0_parameter):
    """Run one A0 call for a single weight guess.

    Parameters
    ----------
    A0 : callable
        Called as ``A0(samples, alpha_, A0_parameter)``; must return a
        ``(centers, alphas)`` pair.
    alpha_ : float
        Weight guess handed to ``A0``.
    samples
        Data handed to ``A0``.
    A0_parameter
        Extra argument handed to ``A0``.

    Returns
    -------
    tuple
        The ``(centers, alphas)`` pair from ``A0``, or ``(None, None)`` when
        ``A0`` produced no centers.
    """
    centers, alphas = A0(samples, alpha_, A0_parameter)
    return (centers, alphas) if len(centers) >= 1 else (None, None)

def first_stage(A0, samples, alpha_min, A0_parameter=50, constant=0.5, robust=False, multiprocessing=False):
    """Collect candidate means over a grid of weight guesses and prune them.

    ``A0`` is run once per weight guess ``alpha_min + i * alpha_min``; all
    returned centers are pooled and handed to `first_stage_pruning`.

    Parameters
    ----------
    A0 : callable
        Base learner, called as ``A0(samples, alpha_, A0_parameter)`` and
        returning a ``(centers, alphas)`` pair.
    samples : numpy.ndarray
        Data array with one sample per row.
    alpha_min : float
        Smallest weight guess; also used as the grid step.
    A0_parameter : default=50
        Extra argument forwarded to ``A0``.
    constant : float, default=0.5
        Forwarded to `first_stage_pruning`.
    robust : bool, default=False
        When true, every pruned mean is replaced by the result of
        ``rme.robust_mean_estimation``.
    multiprocessing : bool, default=False
        When true, the A0 calls are distributed over a process pool.

    Returns
    -------
    tuple
        ``(means, all_centers, estimated_alphas)``; three empty lists when no
        candidate center was produced. ``means`` is a list when ``robust`` is
        true and otherwise whatever `first_stage_pruning` returns.
    """
    step = alpha_min
    # The grid deliberately stops half-way (divisor 2 * step instead of step)
    # to reduce the number of A0 runs for testing purposes.
    n_guesses = int((1. - alpha_min) / (2 * step))
    G = np.array([alpha_min + i * step for i in range(n_guesses)])
    all_centers = np.empty((0, samples.shape[1]))
    all_alphas = np.empty(0)

    if multiprocessing:
        with Pool(cpu_count()) as pool:
            results = pool.map(lambda alpha_: first_stage_worker_function(A0, alpha_, samples, A0_parameter), G)
            flat_results = [result for result in results if result[0] is not None]
            # Test the filtered list: np.vstack raises on an empty sequence,
            # which happens when every worker came back without centers.
            if flat_results:
                all_centers = np.vstack([result[0] for result in flat_results])
                all_alphas = np.concatenate([result[1] for result in flat_results])
    else:
        for alpha_ in G:
            centers, alphas = A0(samples, alpha_, A0_parameter)
            # The sequential path stops at the first guess that yields nothing.
            if len(centers) < 1:
                break
            all_centers = np.vstack((all_centers, centers))
            all_alphas = np.concatenate((all_alphas, alphas))

    print("\nInitial list size", len(all_centers))
    if len(all_centers) == 0:
        return [], [], []

    pruned_means, estimated_alphas = first_stage_pruning(samples, all_centers, all_alphas, constant=constant)
    print("Pruned list size", len(pruned_means))
    if not robust:
        return pruned_means, all_centers, estimated_alphas

    # Robust mixture learning: replace each pruned mean with a robust estimate.
    robust_means = [
        rme.robust_mean_estimation(samples, mean, 0.5, alpha_min, 1 - alpha, 1)
        for mean, alpha in zip(pruned_means, estimated_alphas)
    ]
    return robust_means, all_centers, estimated_alphas
    
def outer_stage(A0, samples, alpha_min, constant, A0_parameter=50, gamma_factor=0.1, robust=False, multiprocessing=False):
    """Refine the first-stage means by re-running `first_stage` on sample slabs.

    An initial list of means ``M`` is obtained from `first_stage`. For every
    mean still under consideration two nested sample sets are built: the
    samples whose projection onto each direction ``M[i] - M[j]`` stays within
    the strong, respectively the weak (three times larger), threshold. The
    largest strong set whose weak set is at most twice as big is selected,
    `first_stage` is re-run on its weak set, and its strong set is removed
    from the data. If no such set exists, `first_stage` is run once on the
    remaining data and the loop ends.

    Parameters
    ----------
    A0 : callable
        Base learner forwarded to `first_stage`.
    samples : numpy.ndarray
        Data array with one sample per row.
    alpha_min : float
        Smallest component weight considered.
    constant : float
        Scale of the thresholds; also forwarded to `first_stage`.
    A0_parameter : default=50
        Forwarded to `first_stage`.
    gamma_factor : float, default=0.1
        Multiplier applied to ``4 * constant * sqrt(log(1 / alpha_min))`` to
        obtain the strong threshold.
    robust : bool, default=False
        Forwarded to `first_stage`.
    multiprocessing : bool, default=False
        Forwarded to the initial `first_stage` call only; the nested calls
        do not receive it.

    Returns
    -------
    list
        The means collected from the nested `first_stage` calls.
    """
    # Run first stage on S with alpha_min
    GAMMA = 4 * constant * np.sqrt(np.log(1 / alpha_min))
    strong_bound_threshold = gamma_factor * GAMMA
    weak_bound_threshold = 3 * strong_bound_threshold
    print(f"Bound thresholds: {strong_bound_threshold}, {weak_bound_threshold}, alpha_min: {alpha_min}...")
    M, _, _ = first_stage(A0, samples, alpha_min, A0_parameter, constant, robust, multiprocessing)
    print(f"Initial means found: {len(M)}")
    n = len(samples)
    min_strong_size = 100 * (alpha_min**4) * n

    R = list(range(len(M)))
    L = []
    while R:
        # Membership is tracked as boolean masks over the current `samples`
        # so that set sizes and row removal are exact.
        masks = {}
        for i in R:
            mean = M[i]
            centered = samples - mean
            strong_mask = np.ones(len(samples), dtype=bool)
            weak_mask = np.ones(len(samples), dtype=bool)
            for j in range(len(M)):
                if j == i:
                    continue
                vij = (mean - M[j]) / np.linalg.norm(mean - M[j])
                proj = np.abs(np.matmul(centered, vij))
                strong_mask &= proj <= strong_bound_threshold
                weak_mask &= proj <= weak_bound_threshold
            masks[i] = (strong_mask, weak_mask)

        sizes = {i: (int(masks[i][0].sum()), int(masks[i][1].sum())) for i in R}
        # Remove from R every mean whose strong set is too small. The size,
        # not a boolean array of comparisons, must be tested here.
        R = [i for i in R if sizes[i][0] > min_strong_size]
        if not R:
            break

        # Find the largest set which satisfies |S_i(2)| <= 2 * |S_i(1)|.
        max_size = 0
        max_i = None
        for i in R:
            strong_size, weak_size = sizes[i]
            if weak_size <= 2 * strong_size and strong_size > max_size:
                max_size = strong_size
                max_i = i

        if max_i is None:
            # No well-isolated set: run once on everything that is left.
            alpha = alpha_min * (n / len(samples))
            means, _, _ = first_stage(A0, samples, alpha, A0_parameter, constant, robust)
            L.extend(means)
            break

        strong_mask, weak_mask = masks[max_i]
        # Rescale the weight to the size of the subset handed to first_stage.
        alpha = alpha_min * (n / sizes[max_i][1])
        means, _, _ = first_stage(A0, samples[weak_mask], alpha, A0_parameter, constant, robust)
        L.extend(means)
        # Drop exactly the rows of the selected strong set.
        samples = samples[~strong_mask]
    print(f"Found {len(L)} means.")
    return L
