import numpy as np
import scipy.stats
from clustering_utils import get_similarity
from itertools import product
from numpy.random import choice
import scipy.stats

from scipy.cluster.hierarchy import linkage
from scipy.cluster.hierarchy import fcluster
from sklearn.cluster import AgglomerativeClustering 

from copy import deepcopy


# def magnitude_gradient(gradients):
#     magnitudes = []
#     for idx in range(len(gradients)):
#         gradient = gradients[idx][1]
#         magnitudes.append(gradient)
#     return magnitudes

def magnitude_gradient(gradients):
    """Reduce each entry's first gradient array to its per-row mean.

    Only item ``[0]`` of every element in ``gradients`` is read. That item
    must expose a two-element ``shape`` ``(m, n)`` and support row indexing.

    Args:
        gradients: Sequence of indexable per-client gradient containers.

    Returns:
        A list with one float array of length ``m`` per element of
        ``gradients``; entry ``c`` holds ``sum(row c) / n``.
    """

    def _row_means(matrix):
        # Average every row separately: sum the row, divide by column count.
        n_rows, n_cols = matrix.shape
        means = np.zeros(n_rows)
        for row in range(n_rows):
            means[row] = np.sum(matrix[row]) / n_cols
        return means

    return [_row_means(entry[0]) for entry in gradients]

def estimated_entropy_from_grad(magnitudes, T):
    """Estimate one entropy value per client from gradient magnitudes.

    Every entry of ``magnitudes`` is divided by the temperature ``T``, turned
    into a probability vector with a softmax, and scored with
    ``scipy.stats.entropy`` (natural logarithm).

    The input sequence is not modified.

    Args:
        magnitudes: Sequence of numeric arrays, one per client.
        T: Softmax temperature; must be non-zero.

    Returns:
        A list holding one entropy value per entry of ``magnitudes``, in the
        same order.
    """
    # NOTE: this module defines `estimated_entropy_from_grad` a second time
    # further down; that later definition is the one bound after import.
    estimated_H = []
    for client_magnitude in magnitudes:
        logits = np.asarray(client_magnitude) / T
        if logits.size:
            # Shifting by the maximum leaves the softmax unchanged but keeps
            # np.exp from overflowing to inf (which would yield nan entropy).
            logits = logits - np.max(logits)
        weights = np.exp(logits)
        pk = weights / np.sum(weights)
        estimated_H.append(scipy.stats.entropy(pk))
    return estimated_H



def get_matrix_similarity_from_grads_entropy(local_model_grads, estimated_H,distance_type):
    """Build the pairwise client matrix that blends gradient and entropy terms.

    Entry ``[i, j]`` is ``0.1 * g + 0.9 * |estimated_H[i] - estimated_H[j]|``
    where ``g`` is the value returned by ``get_similarity`` (imported from
    ``clustering_utils``) for the gradients of clients ``i`` and ``j`` under
    ``distance_type``.

    Args:
        local_model_grads: Per-client gradients, indexable by client position.
        estimated_H: Per-client entropy estimates, indexable the same way.
        distance_type: Forwarded unchanged to ``get_similarity``.

    Returns:
        A square ``numpy`` array with one row and column per client.
    """
    grad_weight = 0.1
    entropy_weight = 1 - grad_weight
    total = len(local_model_grads)

    blended = np.zeros((total, total))
    # Every ordered pair is evaluated, including (i, i) and both (i, j)/(j, i).
    for row in range(total):
        for col in range(total):
            grad_term = get_similarity(
                local_model_grads[row], local_model_grads[col], distance_type
            )
            entropy_gap = abs(estimated_H[row] - estimated_H[col])
            blended[row, col] = grad_weight * grad_term + entropy_weight * entropy_gap
    return blended


def sample_clients_entropy(entropy, Clusters, n_samples, global_epoch, epoch):
    """Sample ``len(Clusters)`` clients, favouring clusters with high entropy.

    Each cluster gets weight ``exp(4 * (global_epoch - epoch) * entropy /
    global_epoch)``; the weights are normalised into a probability vector and
    printed. One cluster is then drawn per slot (a draw is repeated while the
    chosen cluster has no unassigned client left), and finally the assigned
    number of clients is drawn without replacement from every chosen cluster.
    When ``epoch == global_epoch`` the exponent is zero and all clusters are
    equally likely.

    Args:
        entropy: Array-like with one entropy value per cluster.
        Clusters: Sequence of 1-D arrays of client indices, one per cluster.
        n_samples: Not used by the sampling; kept for interface compatibility.
        global_epoch: Total number of rounds; must be non-zero.
        epoch: Current round.

    Returns:
        A list of sampled client indices, grouped by cluster in cluster order.

    Raises:
        ValueError: If the clusters together hold fewer clients than there
            are slots to fill (the redraw loop could never finish).
    """
    gamma = 4  # sharpness of the entropy weighting
    n_clustered = len(Clusters)
    capacities = [len(cluster) for cluster in Clusters]
    if sum(capacities) < n_clustered:
        raise ValueError(
            "Clusters hold %d clients in total but %d must be sampled"
            % (sum(capacities), n_clustered)
        )

    weights = np.exp(gamma * (global_epoch - epoch) * np.asarray(entropy) / global_epoch)
    p_cluster = weights / np.sum(weights)
    print(p_cluster)

    clusters_selected = [0] * n_clustered
    for _ in range(n_clustered):
        while True:
            # Index the one-element result instead of calling int() on the
            # array, which NumPy deprecates for arrays with ndim > 0.
            select_group = int(choice(n_clustered, 1, p=p_cluster)[0])
            if clusters_selected[select_group] < capacities[select_group]:
                break
        clusters_selected[select_group] += 1

    sampled_clients = []
    for cluster, quota in zip(Clusters, clusters_selected):
        if quota:
            sampled_clients.extend(choice(cluster, quota, replace=False))
    return sampled_clients

def estimated_entropy(estimated_H, Clusters):
    """Average the per-client entropy estimates inside every cluster.

    The index and the members of each cluster are printed while iterating.

    Args:
        estimated_H: Per-client entropy values, indexable by client index.
        Clusters: Sequence of non-empty collections of client indices.

    Returns:
        A ``numpy`` array with the mean entropy of each cluster, in the order
        of ``Clusters``.
    """
    cluster_means = []
    for position, members in enumerate(Clusters):
        print("cluster ", position)
        print(members)

        # Accumulate member by member, starting from the integer 0.
        running_total = 0
        for client in members:
            running_total += estimated_H[client]

        cluster_means.append(running_total / len(members))
    return np.array(cluster_means)

def estimated_entropy_from_grad(magnitudes, T):
    """Estimate one entropy value per client from gradient magnitudes.

    Every entry of ``magnitudes`` is divided by the temperature ``T``, turned
    into a probability vector with a softmax, and scored with
    ``scipy.stats.entropy`` (natural logarithm).

    The input sequence is not modified.

    Args:
        magnitudes: Sequence of numeric arrays, one per client.
        T: Softmax temperature; must be non-zero.

    Returns:
        A list holding one entropy value per entry of ``magnitudes``, in the
        same order.
    """
    # NOTE: an earlier definition with the same name appears above in this
    # module; this later one replaces it once the module has been imported.
    estimated_H = []
    for client_magnitude in magnitudes:
        logits = np.asarray(client_magnitude) / T
        if logits.size:
            # Shifting by the maximum leaves the softmax unchanged but keeps
            # np.exp from overflowing to inf (which would yield nan entropy).
            logits = logits - np.max(logits)
        weights = np.exp(logits)
        pk = weights / np.sum(weights)
        estimated_H.append(scipy.stats.entropy(pk))
    return estimated_H

def HiCS_sampling(gradients,magnitudes,T, sim_type,  n_samples, n_sampled, global_epoch, epoch):
    """Cluster clients by gradient/entropy distance and sample from the clusters.

    Steps, all delegated to functions of this module:
      1. ``estimated_entropy_from_grad`` turns ``magnitudes`` into one entropy
         estimate per client (the estimates are printed).
      2. ``get_matrix_similarity_from_grads_entropy`` builds the pairwise
         client matrix from ``gradients`` and those estimates.
      3. The rows of that matrix are grouped into ``n_sampled`` clusters with
         Ward-linkage agglomerative clustering.
      4. ``estimated_entropy`` averages the estimates per cluster and
         ``sample_clients_entropy`` draws the clients.

    Args:
        gradients: Per-client gradients used for the pairwise comparison.
        magnitudes: Per-client gradient magnitudes used for the entropy.
        T: Softmax temperature for the entropy estimate.
        sim_type: Distance type forwarded to the pairwise comparison.
        n_samples: Forwarded unchanged to ``sample_clients_entropy``.
        n_sampled: Number of clusters to form.
        global_epoch: Total number of rounds.
        epoch: Current round.

    Returns:
        Whatever ``sample_clients_entropy`` returns (a list of client indices).
    """
    estimated_H = estimated_entropy_from_grad(magnitudes,T)
    print(estimated_H)
    sim_matrix = get_matrix_similarity_from_grads_entropy(gradients, estimated_H, distance_type=sim_type)

    # Ward linkage only works with Euclidean distances, which is already the
    # estimator's default. The explicit `affinity="euclidean"` keyword is
    # omitted because scikit-learn deprecated it in 1.2 and removed it in 1.4,
    # while its replacement `metric` does not exist in older releases.
    hc = AgglomerativeClustering(n_clusters=n_sampled, linkage="ward")
    labels = hc.fit_predict(sim_matrix)

    # One array of client indices per cluster label 0 .. n_sampled - 1.
    Clusters = [np.where(labels == i)[0] for i in range(n_sampled)]

    avg_entropy = estimated_entropy(estimated_H,Clusters)

    return sample_clients_entropy(avg_entropy, Clusters,n_samples, global_epoch, epoch)