import numpy as np

from sklearn_extra.robust import RobustWeightedKMeans
from sklearn.cluster import KMeans
from sklearn.cluster import DBSCAN
import pandas as pd

from pathos.multiprocessing import ProcessingPool as Pool
from multiprocessing import cpu_count

import load_synthetic_data as load_data
import ldme
import rme
import eval_performance as eval_perf

import robust_mixture_learning as rml

### Parallel processing functions

def worker_function_gaussian_multifilter(seed, S, true_centers, alpha_min, r_factor, d_factor, c_factor, n_dir_factor, naive_clusters_factor, k_factor):
    """Run one seeded Gaussian-multifilter trial and score the result.

    Seeds NumPy's global RNG with ``seed``, calls
    ``ldme.fast_gaussian_multifilter`` on the samples ``S`` with the given
    tuning factors, and evaluates the returned means against ``true_centers``.

    Returns:
        A ``(error, list_size)`` pair, where ``error`` is
        ``eval_perf.dist_metric(true_centers, means)`` and ``list_size`` is
        ``len(means)``.
    """
    np.random.seed(seed)
    estimated_means = ldme.fast_gaussian_multifilter(
        S, alpha_min, r_factor, d_factor, c_factor,
        n_dir_factor, naive_clusters_factor, k_factor,
    )
    return eval_perf.dist_metric(true_centers, estimated_means), len(estimated_means)

def worker_function_bounded_multifilter(seed, S, true_centers, alpha_min, r_factor, d_factor, c_factor, n_dir_factor, naive_clusters_factor, gamma_factor_inner, delta_dist_factor, runs_factor, delta, beta):
    """Run one seeded bounded-multifilter trial and score the result.

    Seeds NumPy's global RNG with ``seed``, calls ``ldme.fast_multifilter``
    on the samples ``S``, and evaluates the returned means against
    ``true_centers``. Note that ``ldme.fast_multifilter`` takes its arguments
    in a different order than this worker's own signature.

    Returns:
        A ``(error, list_size)`` pair, where ``error`` is
        ``eval_perf.dist_metric(true_centers, means)`` and ``list_size`` is
        ``len(means)``.
    """
    np.random.seed(seed)
    estimated_means = ldme.fast_multifilter(
        S, alpha_min, delta, beta,
        r_factor, n_dir_factor, gamma_factor_inner, d_factor,
        delta_dist_factor, c_factor, runs_factor, naive_clusters_factor,
    )
    return eval_perf.dist_metric(true_centers, estimated_means), len(estimated_means)

def worker_function_outer_stage(A0, seed, S, true_centers, alpha_min, beta_constant, gamma_factor, A0_parameter, robust, multiprocessing_outer):
    """Run one seeded ``rml.outer_stage`` trial with base learner ``A0``.

    Seeds NumPy's global RNG with ``seed``, forwards the arguments to
    ``rml.outer_stage`` (``beta_constant`` is passed as ``constant`` and
    ``multiprocessing_outer`` as ``multiprocessing``), and evaluates the
    returned means against ``true_centers``.

    Returns:
        A ``(error, list_size)`` pair, where ``error`` is
        ``eval_perf.dist_metric(true_centers, means)`` and ``list_size`` is
        ``len(means)``.
    """
    np.random.seed(seed)
    outer_stage_kwargs = dict(
        A0=A0,
        samples=S,
        alpha_min=alpha_min,
        constant=beta_constant,
        gamma_factor=gamma_factor,
        A0_parameter=A0_parameter,
        robust=robust,
        multiprocessing=multiprocessing_outer,
    )
    estimated_means = rml.outer_stage(**outer_stage_kwargs)
    return eval_perf.dist_metric(true_centers, estimated_means), len(estimated_means)

def worker_function_gauss_multifilter_without_outer(seed, S, true_centers, alpha_min, A0_parameter, beta_constant, robust, multiprocessing_outer):
    """Run one seeded ``rml.first_stage`` trial with the Gaussian multifilter.

    Seeds NumPy's global RNG with ``seed`` and calls ``rml.first_stage`` with
    ``rml.gaussian_multifilter`` as the base learner. ``first_stage`` returns
    three values; only the first (the pruned means) is scored here.

    Returns:
        A ``(error, list_size)`` pair, where ``error`` is
        ``eval_perf.dist_metric(true_centers, pruned_means)`` and
        ``list_size`` is ``len(pruned_means)``.
    """
    np.random.seed(seed)
    first_stage_kwargs = dict(
        A0=rml.gaussian_multifilter,
        samples=S,
        alpha_min=alpha_min,
        A0_parameter=A0_parameter,
        constant=beta_constant,
        multiprocessing=multiprocessing_outer,
        robust=robust,
    )
    # Exactly three return values are expected; the last two are not used.
    kept_means, _centers_unused, _ = rml.first_stage(**first_stage_kwargs)
    return eval_perf.dist_metric(true_centers, kept_means), len(kept_means)


### Main function to run all algorithms
    
def get_clustering_advantage(seed, r,d=2,constant=0.5, noise_model='gaussian', dataset='mixture', robust=False, dist='gauss', multiprocessing_outer=False, multiprocessing_adv=False, gauss_multifilter=False, weight_min = 0.01, params=None, iterations_per_seed = 10):
    """Generate one noisy dataset and benchmark several clustering methods on it.

    The dataset is built (or loaded) after seeding NumPy's global RNG with
    ``seed``; noise is then added through ``load_data.load_noise_model``.
    Afterwards the following methods are run and scored with
    ``eval_perf.dist_metric`` against the true centers:

    * Gaussian multifilter (``ldme.fast_gaussian_multifilter``) and the full
      robust mixture learning pipeline (``rml.outer_stage``) -- only when
      ``gauss_multifilter`` is true,
    * DBSCAN over a grid of ``eps`` values,
    * k-means and robust k-means over a grid of cluster counts.

    Args:
        seed: Seed for NumPy's global RNG used for data generation.
        r: Forwarded to ``load_data.load_noise_model``.
        d: Dimension passed to the synthetic mixture generator
            (``dataset='mixture'`` only).
        constant: Accepted for interface compatibility; not read by the
            currently active code paths.
        noise_model: Noise model name forwarded to ``load_noise_model``.
        dataset: One of ``'mixture'``, ``'mnist'`` or ``'genomic'``.
        robust: Forwarded to ``rml.outer_stage``.
        dist: Distribution name forwarded to the data/noise loaders.
        multiprocessing_outer: Forwarded to ``rml.outer_stage`` as
            ``multiprocessing``.
        multiprocessing_adv: If true, the repeated multifilter / RML trials
            are distributed over a process pool (one seeded trial per
            integer in ``range(iterations_per_seed)``).
        gauss_multifilter: Enables the Gaussian multifilter and full RML
            sections.
        weight_min: Forwarded to ``load_data.load_genomic_dataset``.
        params: Optional 7-tuple ``(d_factor, r_factor, c_factor, k_factor,
            n_dir_factor, gamma_factor, beta_constant)`` overriding the
            built-in tuning defaults.
        iterations_per_seed: Number of repeated trials per method, and the
            repetition count of each k in the (non-genomic) k-means grids.

    Returns:
        A 17-tuple of lists/arrays, in this order:
        ``pruned_means_dist_arr, dist_kmeans_arr, pruned_means_list_size,
        k_arr, dbscan_list_size, dbscan_dist_arr, dist_rob_kmeans_arr,
        rml_dist_arr, rml_list_size_arr, rml_bounded_dist_arr,
        rml_bounded_list_size_arr, rml_dbk_dist_arr, rml_dbk_list_size_arr,
        gauss_dist_arr, gauss_list_size_arr, bounded_dist_arr,
        bounded_list_size_arr``.
        Lists belonging to disabled experiments are returned empty.
        ``dist_rob_kmeans_arr`` has one entry per element of the robust
        k-means grid; a run that raised is recorded as ``nan`` so that the
        list stays positionally aligned with ``k_arr``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Fail early with a clear message; previously an unknown name fell through
    # the if/elif chain and surfaced later as an UnboundLocalError on ``S``.
    if dataset not in ('mixture', 'mnist', 'genomic'):
        raise ValueError(
            "Unknown dataset {!r}; expected 'mixture', 'mnist' or 'genomic'".format(dataset)
        )

    epsilon_add = 0.1
    epsilon = epsilon_add / (1 + epsilon_add)
    weights = [0.3, 0.2, 0.2, 0.1, 0.1, 0.05, 0.03, 0.02]
    alpha_min=min(weights)*(1-epsilon)*0.9

    np.random.seed(seed)

    ### Load data and noise model
    if(dataset == 'mixture'):
        num_clusters = len(weights)
        S, true_centers = load_data.generate_mixture_data_with_separated_centers(num_samples=10000, separation=40, num_clusters=num_clusters, 
                                                 weights=weights, d=d, dist=dist)
        S = load_data.load_noise_model(noise_model, S, true_centers, weights, r, epsilon_add, dist)
    elif(dataset == 'mnist'):
        # NOTE(review): ``weights`` has 8 entries while num_clusters=10 is
        # requested here -- confirm that load_noise_model copes with this.
        S, true_centers = load_data.load_and_project_mnist_with_clustering(num_samples=1000, k=2, num_clusters=10)
        S = load_data.load_noise_model(noise_model, S, true_centers, weights, r, epsilon_add, dist)
    elif(dataset == 'genomic'):
        S, true_centers, weights, cluster_cov = load_data.load_genomic_dataset(weight_min=weight_min)
        alpha_min = min(weights)*(1-epsilon)*0.9 / np.sum(weights) # normalize to account for weight_min cut-off
        S = load_data.load_noise_model(noise_model, S, true_centers, weights, r, epsilon_add, dist, cluster_cov)

    print("First value of S to check seed: ", S[0])

    ### Run Robust Mixture-Learning algo with Gaussian Multifilter as baselearner (no outer stage)
    # The experiment below is disabled (kept as an inert string), so both
    # lists are always returned empty.
    pruned_means_dist_arr = []
    pruned_means_list_size = []
    """print(">>> Run First Stage with Gaussian Multifilter <<<")
    d_factor = 0.2              
    r_factor = 0.9  #0.9              
    c_factor = 0.5 #0.5                
    n_dir_factor = 1.1            
    naive_clusters_factor = 1
    k_factor = 1
    beta_constant = 1 #1
    A0_parameter = (d_factor, r_factor, c_factor, naive_clusters_factor, n_dir_factor, k_factor)

    if multiprocessing_adv:
        with Pool(cpu_count()) as pool:
            results = pool.map(lambda seed: worker_function_gauss_multifilter_without_outer(seed, S, true_centers, alpha_min, A0_parameter, beta_constant, robust, multiprocessing_outer), range(iterations_per_seed))
        pruned_means_dist_arr = [result[0] for result in results]
        pruned_means_list_size = [result[1] for result in results]
    else:
        for _ in range(iterations_per_seed):
            pruned_means, _,_ = rml.first_stage(A0=rml.gaussian_multifilter, samples=S, alpha_min=alpha_min, A0_parameter=A0_parameter, constant=beta_constant, multiprocessing=multiprocessing_outer, robust=robust)
            if len(pruned_means) <1:
                continue
            pruned_means_dist_arr.append(eval_perf.dist_metric(true_centers, pruned_means))
            pruned_means_list_size.append(len(pruned_means))"""

    ### Run Gaussian Multifilter without Outer Stage
    gauss_dist_arr = []
    gauss_list_size_arr = []
    if(gauss_multifilter):
        print(">>> Run Gaussian Multifilter (LDME) <<<")
        d_factor = 0.2              
        r_factor = 0.9  #0.9              
        c_factor = 0.5 #0.5                
        n_dir_factor = 1.1            
        naive_clusters_factor = 1
        k_factor = 1

        if(dataset == 'genomic'):
            d_factor = 0.2
            r_factor = 0.01
            c_factor = 1
            k_factor = 1
            n_dir_factor = 1.1

        if params is not None:
            # The last two entries (gamma_factor, beta_constant) only matter
            # for the outer stage and are ignored here.
            d_factor, r_factor, c_factor, k_factor, n_dir_factor, _, _ = params

        if multiprocessing_adv:
            with Pool(cpu_count()) as pool:
                results = pool.map(lambda seed: worker_function_gaussian_multifilter(seed, S, true_centers, alpha_min, r_factor, d_factor, c_factor, n_dir_factor, naive_clusters_factor, k_factor), range(iterations_per_seed))
            gauss_dist_arr = [result[0] for result in results]
            gauss_list_size_arr = [result[1] for result in results]
        else:
            for _ in range(iterations_per_seed):
                gaussian_means = ldme.fast_gaussian_multifilter(S, alpha_min, r_factor, d_factor, c_factor, n_dir_factor, naive_clusters_factor, k_factor)
                gauss_dist_arr.append(eval_perf.dist_metric(true_centers, gaussian_means))
                gauss_list_size_arr.append(len(gaussian_means))

    ### Run Robust Mixture-Learning with Outer and Inner Stage
    rml_dist_arr = []
    rml_list_size_arr = []
    if(gauss_multifilter):
        print(">>> Robust Mixture Learning (Ours) with Outer and Inner Stage <<<")
        # gauss optimized params
        d_factor = 0.2              
        r_factor = 0.9   # 0.9             
        c_factor = 0.5 #0.5                
        n_dir_factor = 1.1            
        naive_clusters_factor = 1
        k_factor = 1
        gamma_factor = 0.1 #0.1
        beta_constant = 1 #1

        if(dataset == 'genomic'):
            d_factor = 0.2
            r_factor = 0.01
            c_factor = 1
            k_factor = 1
            n_dir_factor = 1.1
            
            gamma_factor = 0.5
            beta_constant = 0.03

        if params is not None:
            d_factor, r_factor, c_factor, k_factor, n_dir_factor, gamma_factor, beta_constant = params

        A0_parameter = (d_factor, r_factor, c_factor, naive_clusters_factor, n_dir_factor, k_factor)
        if multiprocessing_adv:
            with Pool(cpu_count()) as pool:
                results = pool.map(lambda seed: worker_function_outer_stage(rml.gaussian_multifilter, seed, S, true_centers, alpha_min, beta_constant, gamma_factor, A0_parameter, robust, multiprocessing_outer), range(iterations_per_seed))
            rml_dist_arr = [result[0] for result in results]
            rml_list_size_arr = [result[1] for result in results]
        else:
            for _ in range(iterations_per_seed):
                rml_means = rml.outer_stage(A0=rml.gaussian_multifilter, samples=S, alpha_min=alpha_min, constant=beta_constant, gamma_factor=gamma_factor, A0_parameter=A0_parameter, robust=robust, multiprocessing=multiprocessing_outer)
                rml_dist_arr.append(eval_perf.dist_metric(true_centers, rml_means))
                rml_list_size_arr.append(len(rml_means))
    
    ### Run Bounded Multifilter without Outer Stage
    # Disabled experiment (inert string below); lists are returned empty.
    bounded_dist_arr = []
    bounded_list_size_arr = []
    
    """if not gauss_multifilter:
        print(">>> Run Bounded Multifilter without Outer Stage <<<")
        r_factor = 0.4  
        n_dir_factor = 0.1      
        gamma_factor_inner = 0.0001 
        d_factor = 1 
        delta_dist_factor = 0.02 
        c_factor = 5   
        runs_factor = 1/20   
        naive_cluster_factor = 1
        delta = (1 / (d**2))
        beta = 1 / np.log(d) 

        if multiprocessing_adv:
            with Pool(cpu_count()) as pool:
                results = pool.map(lambda seed: worker_function_bounded_multifilter(seed, S, true_centers, alpha_min, r_factor, d_factor, c_factor, n_dir_factor, naive_cluster_factor, gamma_factor_inner, delta_dist_factor, runs_factor, delta, beta), range(8))
            bounded_dist_arr = [result[0] for result in results]
            bounded_list_size_arr = [result[1] for result in results]
        else:
            for _ in range(8): 
                bounded_means = ldme.fast_multifilter(S, alpha_min, delta, beta, r_factor, n_dir_factor, gamma_factor_inner, d_factor, delta_dist_factor, c_factor, runs_factor, naive_cluster_factor)
                bounded_dist_arr.append(eval_perf.dist_metric(true_centers, bounded_means))
                bounded_list_size_arr.append(len(bounded_means))"""
            
    ### Run Full RML with Bounded Multifilter
    # Disabled experiment (inert string below); lists are returned empty.
    rml_bounded_dist_arr = []
    rml_bounded_list_size_arr = []

    """if not gauss_multifilter:
        print("Robust Mixture Learning with Bounded Multifilter ...")
        r_factor = 0.4  
        n_dir_factor = 0.1
        gamma_factor_inner = 0.0001 
        d_factor = 1 
        delta_dist_factor = 0.02 
        c_factor = 5   
        runs_factor = 1/20   
        naive_cluster_factor = 1
        beta_constant = 1
        gamma_factor_outer = 0.1
        A0_parameter=(d_factor, r_factor, c_factor, naive_cluster_factor, n_dir_factor, gamma_factor_inner, delta_dist_factor, runs_factor)

        if multiprocessing_adv:
            with Pool(cpu_count()) as pool:
                results = pool.map(lambda seed: worker_function_outer_stage(rml.bounded_multifilter, seed, S, true_centers, alpha_min, beta_constant, gamma_factor_outer, A0_parameter, robust, multiprocessing_outer), range(20))
            rml_bounded_dist_arr = [result[0] for result in results]
            rml_bounded_list_size_arr = [result[1] for result in results]
        else:
            for _ in range(20):
                rml_means = rml.outer_stage(A0=rml.bounded_multifilter, samples=S, alpha_min=alpha_min, constant=beta_constant, gamma_factor=gamma_factor_outer, A0_parameter=A0_parameter, robust=robust, multiprocessing=multiprocessing_outer)
                rml_bounded_dist_arr.append(eval_perf.dist_metric(true_centers, rml_means))
                rml_bounded_list_size_arr.append(len(rml_means))"""

    ### Run DBSCAN
    print(">>> Run DBSCAN <<<")
    dbscan_dist_arr = []
    dbscan_list_size = []
    eps_params = np.linspace(1, 50, iterations_per_seed)
    min_samples = 5
    if(dataset == 'genomic'):
        eps_params = np.linspace(0.05, 0.15, 50)
        min_samples = 3
    for eps_dbscan in eps_params:
        dbscan_centers, _ = rml.dbscan_A0_full(S, 0.1, eps_dbscan, min_samples)
        # Runs that produce no centers are skipped; this is safe because the
        # list size is recorded alongside each error.
        if len(dbscan_centers) <1:
            continue
        dbscan_dist_arr.append(eval_perf.dist_metric(true_centers, dbscan_centers))
        dbscan_list_size.append(len(dbscan_centers))

    ### Run Robust Mixture-Learning algo with dbscan and kmeans as baselearner (outer stage) - just for testing effect of outer stage
    # Nothing is computed for this variant; the lists are returned empty.
    rml_dbk_dist_arr = []
    rml_dbk_list_size_arr = []

    ### Run Kmeans
    print("Run Kmeans ...")
    dist_kmeans_arr = []
    #k_arr = np.linspace(5, 45, 20).astype(np.int32)
    k_arr = np.linspace(5, 40, 20).astype(np.int32)
    k_arr = np.repeat(k_arr, iterations_per_seed)
    if(dataset == 'genomic'):
        k_arr = np.linspace(2, 35, 20).astype(np.int32)
    for num_kmeans_centers in k_arr:
        kmeans_once_centers,_ = rml.kmean_A0(S, 0.1, num_kmeans_centers)
        kmeans_once_centers = np.asarray(kmeans_once_centers)
        dist_kmeans_arr.append(eval_perf.dist_metric(true_centers, kmeans_once_centers))
    
    ### Run Robust Kmeans
    print("Run Robust Kmeans ...")
    dist_rob_kmeans_arr = []
    #rob_k_arr = np.linspace(5, 45, 20).astype(np.int32)
    rob_k_arr = np.linspace(5, 40, 20).astype(np.int32)
    rob_k_arr = np.repeat(rob_k_arr, iterations_per_seed)
    if(dataset == 'genomic'):
        rob_k_arr = np.linspace(2, 35, 20).astype(np.int32)
    for num_kmeans_centers in rob_k_arr:
        # Robust k-means can fail on some inputs (e.g. "IndexError: Out of
        # bounds on buffer access (axis 0)"). Catch ordinary exceptions only,
        # so that KeyboardInterrupt / SystemExit still propagate.
        try:
            kmeans_once_centers,_ = rml.robust_kmeans_A0(S, 0.1, num_kmeans_centers)
            kmeans_once_centers = np.asarray(kmeans_once_centers)
            rob_kmeans_err = eval_perf.dist_metric(true_centers, kmeans_once_centers)
        except Exception:
            print("Error due to stability issues of robust kmeans on this dataset for num_clusters: {}".format(num_kmeans_centers))
            # Record a placeholder instead of dropping the run: the caller
            # pairs these errors with k_arr by position, so a missing entry
            # would shift every later error onto the wrong cluster count.
            rob_kmeans_err = np.nan
        dist_rob_kmeans_arr.append(rob_kmeans_err)
    
    return pruned_means_dist_arr, dist_kmeans_arr, pruned_means_list_size, k_arr, dbscan_list_size, dbscan_dist_arr, dist_rob_kmeans_arr, rml_dist_arr, rml_list_size_arr, rml_bounded_dist_arr, rml_bounded_list_size_arr, rml_dbk_dist_arr, rml_dbk_list_size_arr, gauss_dist_arr, gauss_list_size_arr, bounded_dist_arr, bounded_list_size_arr


### Function to collect experimental results for multiple runs

def run_mixture_learning_experiment(num_runs=3, r=10, d=100, pruning_constant=20, noise_model='gaussian', dataset='mixture', robust=False, dist='gauss', multiprocessing_outer=False, multiprocessing_adv=False, gauss_multifilter=False, weight_min=0.01, params=None, iterations_per_seed=10):
    """Run the benchmark for several seeds and collect the results in a table.

    For each ``seed`` in ``range(num_runs)`` this calls
    ``get_clustering_advantage`` and appends one row per (list size, error)
    pair and method to a long-format table with the columns ``method``,
    ``list_size``, ``error`` and ``seed``. The table is written to
    ``'<noise_model>_list.csv'`` in the current working directory (without
    the index) and also returned.

    Returns:
        The collected results as a ``pandas.DataFrame``.
    """
    data = {'method': [], 'list_size': [], 'error': [], 'seed': []}

    def _record(method, sizes, errors, seed):
        # Pairing is positional and stops at the shorter sequence, so it
        # relies on ``errors`` having one entry per element of ``sizes``.
        for size, error in zip(sizes, errors):
            data['method'].append(method)
            data['list_size'].append(size)
            data['error'].append(error)
            data['seed'].append(seed)

    for seed in range(num_runs):
        print("Seed: ", seed+1)
        (
            our_err, kmeans_err, list_size, k_arr,
            dbscan_list_size, dbscan_dist_arr, rob_kmeans_err,
            rml_err, rml_list_size,
            rml_bounded_err, rml_bounded_list_size,
            _rml_dbk_err, _rml_dbk_list_size,  # returned but not recorded
            gauss_dist_arr, gauss_list_size_arr,
            bounded_dist_arr, bounded_list_size_arr,
        ) = get_clustering_advantage(seed, r, d, pruning_constant, noise_model, dataset, robust, dist, multiprocessing_outer, multiprocessing_adv, gauss_multifilter, weight_min, params, iterations_per_seed)

        # (method label, list sizes, errors) in the order rows are emitted.
        # Both k-means variants are paired with the same k_arr.
        per_method = (
            ('Kmeans', k_arr, kmeans_err),
            ('Robust Kmeans', k_arr, rob_kmeans_err),
            ('DBScan', dbscan_list_size, dbscan_dist_arr),
            ('LDME without Outer', list_size, our_err),
            ('Full RML', rml_list_size, rml_err),
            ('Bounded RML', rml_bounded_list_size, rml_bounded_err),
            ('Gaussian Multifilter', gauss_list_size_arr, gauss_dist_arr),
            ('Bounded Multifilter', bounded_list_size_arr, bounded_dist_arr),
        )
        for method, sizes, errors in per_method:
            _record(method, sizes, errors, seed)

    df = pd.DataFrame(data)
    df.to_csv('{}_list.csv'.format(noise_model), index = False)
    return df
