import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import f1_score
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import homogeneity_score
from sklearn.metrics.cluster import completeness_score
from sklearn.metrics.cluster import v_measure_score

from scipy import stats
from scipy.stats import t
from math import sqrt
from statistics import stdev
from statsmodels.stats.contingency_tables import mcnemar

import itertools
import pickle
import math
import copy
import os
import re
import networkx as nx

import matplotlib.pyplot as plt


# Build a '<first>-<last>' string from the first and last items of a list
def GenerateRangeStr(input_list):
    """Return ``'<first>-<last>'`` built from the first and last items of ``input_list``.

    Both items are converted with ``str``. An empty sequence raises IndexError.
    """
    bounds = (input_list[0], input_list[-1])
    return '-'.join(str(item) for item in bounds)

# Check the consistency of the clustering results for a given sequence
def ClusterConsistencyofSeq(cluster_lab):
    """Return the most frequent label of one sequence and the fraction of points carrying it.

    Parameters
    ----------
    cluster_lab : list
        Cluster label of every point of the sequence (must support ``.count``).

    Returns
    -------
    (label, fraction)
        Among tied labels, the one visited first when iterating ``set(cluster_lab)`` is returned.
    """
    candidates = set(cluster_lab)
    dominant = max(candidates, key=cluster_lab.count)
    dominant_share = cluster_lab.count(dominant) / len(cluster_lab)
    return dominant, dominant_share


# Get the dictionary for index of different labels
def GetUniqueLabelDict(all_label):
    """Map every distinct label to the positions at which it occurs.

    Built in a single O(n) pass; the previous version rescanned the whole input once per distinct
    label (O(n * k)).

    Parameters
    ----------
    all_label : iterable of hashable labels

    Returns
    -------
    dict
        label -> ascending list of indices where that label appears. Keys are in order of first
        appearance.
    """
    from collections import defaultdict

    label_dict = defaultdict(list)
    for idx, element in enumerate(all_label):
        label_dict[element].append(idx)
    return dict(label_dict)


# Check the accuracy of the clustering results
def CheckClusterPurity(pred_label, true_label, true_label_dict, verbo=False):
    """Compute cluster purity and map every predicted cluster to its majority ground-truth label.

    Parameters
    ----------
    pred_label : sequence
        Predicted cluster label of each point.
    true_label : sequence
        Ground-truth label of each point, indexed by the positions of ``pred_label``.
    true_label_dict : dict
        Ground-truth label -> display name; only used when ``verbo`` is True.
    verbo : bool
        Print per-cluster and per-class diagnostics.

    Returns
    -------
    ([purity, class_purity_list, purity_list], pred2true_dict)
        purity : float
            Fraction of points whose ground-truth label equals the majority label of their cluster
            (0.0 for empty input).
        class_purity_list : list of float
            For every ground-truth label that is the majority label of at least one cluster (ascending),
            the purity pooled over all clusters mapped to it.
        purity_list : list of float
            Purity of each predicted cluster, in ascending order of cluster label.
        pred2true_dict : dict
            Predicted cluster label -> majority ground-truth label (ascending cluster order). Ties go to
            the label that appears first among the cluster's points.
    """
    from collections import Counter, defaultdict

    # Positions of the points in each predicted cluster, built in a single pass
    pred_dict = defaultdict(list)
    for idx, lab in enumerate(pred_label):
        pred_dict[lab].append(idx)

    if not pred_dict:
        # No points: report zero purity instead of dividing by zero
        return [0.0, [], []], {}

    # Per predicted cluster (ascending cluster label):
    # purity_list: purity, clLab_list: majority ground-truth label,
    # pred_count: number of points, majority_count: points carrying the majority label
    purity_list, clLab_list, pred_count, majority_count = [], [], [], []
    pred2true_dict = {}
    for tmp_lab in sorted(pred_dict):
        members = pred_dict[tmp_lab]
        if verbo: print('# of Predicted Cluster ' + str(tmp_lab) + ': ', len(members))

        # Ground-truth labels of the points in the current cluster
        tmp_gt = [true_label[i] for i in members]
        gt_counts = Counter(tmp_gt)
        # Counter keeps first-appearance order and max() returns the first maximum, so ties go to the
        # label appearing first in the cluster (same rule as max(tmp_gt, key=tmp_gt.count), but O(n))
        cl_lab = max(gt_counts, key=gt_counts.get)
        clLab_list.append(cl_lab)
        pred2true_dict[tmp_lab] = cl_lab
        if verbo: print('Pred cluster: ', true_label_dict[cl_lab])

        tmp_purity = gt_counts[cl_lab] / len(tmp_gt)
        purity_list.append(tmp_purity)
        pred_count.append(len(tmp_gt))
        majority_count.append(gt_counts[cl_lab])

        if verbo: print('Purity: ', tmp_purity, '\n--------------------')

    # Overall purity: a point counts as correct when its true label is its cluster's majority label
    purity = sum(majority_count) / sum(pred_count)
    if verbo: print('Overall purity: ', purity)

    # Purity for each ground-truth label, pooled over all clusters mapped to it
    class_purity_list = []
    for lab in sorted(set(clLab_list)):
        tmp_idx = [idx for idx, mapped in enumerate(clLab_list) if mapped == lab]
        tmp_purity = sum(majority_count[idx] for idx in tmp_idx) / sum(pred_count[idx] for idx in tmp_idx)
        class_purity_list.append(tmp_purity)
        if verbo: print('Purity of cluster {}: {}'.format(true_label_dict[int(lab)], tmp_purity))

    return [purity, class_purity_list, purity_list], pred2true_dict


def ModelEvaluate(real_lab, pre_lab, verbo=False):
    """Compute the confusion matrix and accuracy/recall/precision/F-score with sklearn's default averaging.

    Parameters
    ----------
    real_lab : sequence
        Ground-truth labels.
    pre_lab : sequence
        Predicted labels.
    verbo : bool
        Print the results.

    Returns
    -------
    (conf_matrix, [acc, recall, precision, fscore])
    """
    conf_matrix = confusion_matrix(real_lab, pre_lab)
    acc = accuracy_score(real_lab, pre_lab)
    # These two were swapped (recall came from precision_score and vice versa)
    recall = recall_score(real_lab, pre_lab)
    precision = precision_score(real_lab, pre_lab)
    fscore = f1_score(real_lab, pre_lab)

    if verbo:
        print('Performance Measurements:')
        print('Confusion matrix: \n', conf_matrix)
        print('Accuracy: ', acc)
        print('Recall: ', recall)
        print('Precision: ', precision)
        print('F-score: ', fscore)
    return conf_matrix, [acc, recall, precision, fscore]


def ModelEvaluate_orig(real_lab, pre_lab, labels_list, avg_patterns='weighted', verbo=False):
    """Compute the confusion matrix and averaged classification scores restricted to ``labels_list``.

    Parameters
    ----------
    real_lab, pre_lab : sequence
        Ground-truth and predicted labels.
    labels_list : list
        Labels used for the confusion matrix rows/columns and for recall/precision/F-score.
    avg_patterns : str
        sklearn ``average`` mode, e.g. 'weighted', 'micro' or 'macro'.
    verbo : bool
        Print the results.

    Returns
    -------
    (conf_matrix, [acc, recall, precision, fscore])
        Accuracy is computed over all points, independently of ``labels_list``.
    """
    shared_kwargs = dict(labels=labels_list, average=avg_patterns)
    conf_matrix = confusion_matrix(real_lab, pre_lab, labels=labels_list)
    scores = [
        accuracy_score(real_lab, pre_lab),
        recall_score(real_lab, pre_lab, **shared_kwargs),
        precision_score(real_lab, pre_lab, **shared_kwargs),
        f1_score(real_lab, pre_lab, **shared_kwargs),
    ]

    if verbo:
        print('Performance Measurements:')
        print('Confusion matrix: \n', conf_matrix)
        for caption, value in zip(('Accuracy: ', 'Recall: ', 'Precision: ', 'F-score: '), scores):
            print(caption, value)
    return conf_matrix, scores

def DisplayEvaluationResults(metrics):
    """Print a confusion matrix followed by accuracy, recall, precision and F-score.

    Parameters
    ----------
    metrics : sequence
        ``metrics[0]`` is the confusion matrix; ``metrics[1][0:4]`` are accuracy, recall, precision and
        F-score (any further entries are ignored).
    """
    print('Confusion matrix: ')
    print(metrics[0])

    captions = ('Accuracy: ', 'Recall: ', 'Precision: ', 'F-score: ')
    for pos, caption in enumerate(captions):
        print(caption, metrics[1][pos])


# Evaluate the testing results
def ModelTestResult(te_real_lab, te_pred_labels, true_label_dict,
                    tolerance=True, skip_num=3, skip_pattern='outside', verbo=False):
    """Evaluate predicted test labels, optionally ignoring points close to ground-truth label changes.

    Parameters
    ----------
    te_real_lab : sequence
        Ground-truth label of each test point, in sequence order.
    te_pred_labels : sequence
        Predicted label of each test point (same positions).
    true_label_dict : dict
        Unused; kept for interface compatibility.
    tolerance : bool
        If True, select points relative to the boundaries between runs of identical labels in
        ``te_real_lab``. With ``skip_pattern='outside'``, the ``skip_num`` points on each side of every
        internal boundary are skipped; with ``'inside'``, only those points are kept. The first and last
        ``skip_num`` points of the sequence count as non-boundary points.
    skip_num : int
        Half-width of the tolerance window.
    skip_pattern : {'outside', 'inside'}
    verbo : bool
        Unused; kept for interface compatibility.

    Returns
    -------
    te_metrics : tuple
        ``(confusion_matrix, [acc, recall, precision, fscore, NMI, ARI, homogeneity, completeness,
        v_measure])``, as computed by ``ModelEvaluate_orig`` over the labels present in the evaluated
        ground truth (in sorted order).
    eval_idx : list of int
        Evaluated positions, ascending and without duplicates.

    Raises
    ------
    ValueError
        If ``tolerance`` is True and ``skip_pattern`` is not 'outside' or 'inside'.
    """
    n_points = len(te_real_lab)
    all_idx = list(range(n_points))
    if tolerance:
        if skip_pattern not in ('outside', 'inside'):
            raise ValueError("skip_pattern must be 'outside' or 'inside', got {!r}".format(skip_pattern))
        # Exclusive end / start position of every run of identical consecutive labels
        end_idx = list(itertools.accumulate(sum(1 for _ in grp) for _, grp in itertools.groupby(te_real_lab)))
        start_idx = [0] + end_idx[:-1]

        # Points kept by 'outside': the head of the sequence, the interior of every run and the tail.
        # A set prevents duplicates when head and tail overlap.
        keep = set(range(min(skip_num, n_points)))
        for run_start, run_end in zip(start_idx, end_idx):
            keep.update(range(run_start + skip_num, run_end - skip_num))
        if skip_num > 0:  # all_idx[-0:] is the whole sequence and would re-add every point
            keep.update(all_idx[-skip_num:])

        if skip_pattern == 'outside':
            eval_idx = sorted(keep)
        else:  # 'inside': only the points near a label change
            eval_idx = [i for i in all_idx if i not in keep]
    else:
        eval_idx = all_idx

    real_lab = [te_real_lab[i] for i in eval_idx]
    pred_lab = [te_pred_labels[i] for i in eval_idx]
    # Sorted so the confusion-matrix row/column order does not depend on set iteration order
    te_metrics = ModelEvaluate_orig(real_lab, pred_lab, sorted(set(real_lab)))

    # Clustering-agreement scores appended after accuracy/recall/precision/F-score
    for scorer in (normalized_mutual_info_score, adjusted_rand_score,
                   homogeneity_score, completeness_score, v_measure_score):
        te_metrics[1].append(scorer(real_lab, pred_lab))

    return te_metrics, eval_idx


# ----------------------------------------------------------------------------------------------------------
# Evaluate the clustering results by BIC, purity, and other metrics
def ClusterEvaluation(TICC_return, true_label, idx_ToCompare, true_label_dict, verbo=False):
    """Score a TICC clustering against ground truth on selected positions.

    Parameters
    ----------
    TICC_return : sequence of length 7
        Unpacked as ``(_, pred_list, _, BIC, _, _, _)``. ``pred_list`` is a list of per-sequence
        cluster-label lists that are concatenated before indexing.
    true_label : sequence
        Ground-truth labels aligned with the concatenated predictions.
    idx_ToCompare : list of int
        Positions to evaluate.
    true_label_dict : dict
        Label -> display name, forwarded to ``CheckClusterPurity``.
    verbo : bool
        Forwarded to ``CheckClusterPurity``. BIC, purity and F-score are always printed.

    Returns
    -------
    tuple
        ``(BIC, purity_all, metrics_lab, metrics_cluster, pred2true_dict, pred2true_label)``, where:

        - ``metrics_lab`` is ``(confusion_matrix, [acc, recall, precision, fscore, NMI, ARI, h, c, v])``
          computed after replacing each cluster by its majority label.
        - ``metrics_cluster`` holds one confusion matrix per predicted cluster, in ``pred2true_dict``
          order.
    """
    (_, pred_list, _, BIC, _, _, _) = TICC_return
    print('BIC: ', BIC)

    # Concatenate the per-sequence predictions, then keep only the positions under evaluation
    flat_pred = list(itertools.chain.from_iterable(pred_list))
    pred_label = [flat_pred[pos] for pos in idx_ToCompare]
    comp_true_label = [true_label[pos] for pos in idx_ToCompare]

    purity_all, pred2true_dict = CheckClusterPurity(pred_label, comp_true_label, true_label_dict, verbo=verbo)
    print('Purity: ', purity_all[0])

    # Replace each predicted cluster by its majority ground-truth label and score it as a classifier
    pred2true_label = [int(pred2true_dict[p]) for p in pred_label]
    true_label = [int(g) for g in comp_true_label]
    mapped_labels = list(pred2true_dict.values())
    lab_list = np.unique(mapped_labels).tolist()
    metrics_lab = ModelEvaluate_orig(true_label, pred2true_label, lab_list)
    print('fscore: ', metrics_lab[1][3])

    # Append NMI, ARI, homogeneity, completeness and V-measure
    for scorer in (normalized_mutual_info_score, adjusted_rand_score,
                   homogeneity_score, completeness_score, v_measure_score):
        metrics_lab[1].append(scorer(true_label, pred2true_label))

    # One confusion matrix per predicted cluster over ground-truth labels 0 .. max(mapped label)
    # NOTE(review): true labels larger than the largest mapped label fall outside ``labels`` and are
    # dropped from these matrices; confirm this is intended.
    clus_lab_list = np.arange(int(np.max(mapped_labels) + 1))
    metrics_cluster = []
    for clus_id, mapped in pred2true_dict.items():
        member_truth = [true_label[i] for i, p in enumerate(pred_label) if p == clus_id]
        cm = confusion_matrix(member_truth, [mapped] * len(member_truth), labels=clus_lab_list)
        metrics_cluster.append(cm)

    return BIC, purity_all, metrics_lab, metrics_cluster, pred2true_dict, pred2true_label


# Check the clustering performance
def CheckTICCPerform(TICC_return, true_label, idx_ToCompare, true_label_dict, tolerence=True,
                     skip_num=3, skip_pattern='outside', verbo=False):
    """Evaluate a TICC clustering overall and per cluster, optionally ignoring label-boundary points.

    Parameters
    ----------
    TICC_return : sequence
        TICC output forwarded to ``ClusterEvaluation``.
    true_label : sequence
        Ground-truth label of every point.
    idx_ToCompare : list of int
        Positions evaluated when ``tolerence`` is False. When True, its last ``skip_num`` entries are
        also kept in 'outside' mode.
    true_label_dict : dict
        Ground-truth label -> display name.
    tolerence : bool
        If True, skip (``skip_pattern='outside'``) or keep only (``'inside'``) the points within
        ``skip_num`` positions of a change in ``true_label``. The first and last ``skip_num`` points of
        the sequence count as non-boundary points.
    skip_num : int
        Half-width of the tolerance window.
    skip_pattern : {'outside', 'inside'}
    verbo : bool
        Print detailed results.

    Returns
    -------
    lab_df : pandas.DataFrame
        Confusion matrix of mapped labels, rows/columns named via ``true_label_dict``.
    cluster_df : pandas.DataFrame
        One row per predicted cluster: point counts per ground-truth label name, plus ``clus_lab``,
        ``cluster_size`` and ``cluster_acc``.
    cluster_idx2name : list of (cluster index, label name)
    TICC_metric : tuple
        The full ``ClusterEvaluation`` output.

    Raises
    ------
    ValueError
        If ``tolerence`` is True and ``skip_pattern`` is not 'outside' or 'inside'.
    """
    ## Choose the points to evaluate, optionally using a tolerance window around label changes
    if tolerence:
        if skip_pattern not in ('outside', 'inside'):
            raise ValueError("skip_pattern must be 'outside' or 'inside', got {!r}".format(skip_pattern))
        # Exclusive end / start position of every run of identical consecutive labels
        end_idx = list(np.cumsum([sum(1 for _ in g) for _, g in itertools.groupby(true_label)]))
        start_idx = [0] + end_idx[:-1]

        # Head of the sequence (clipped to its length), interior of every run, tail
        eval_idx = list(np.arange(min(skip_num, len(true_label))))
        for run_start, run_end in zip(start_idx, end_idx):
            eval_idx.extend(np.arange(run_start + skip_num, run_end - skip_num))
        if skip_num > 0:
            # Guarded because idx_ToCompare[-0:] is the whole list and would re-add every index
            eval_idx.extend(idx_ToCompare[-skip_num:])

        if skip_pattern == 'inside':
            # Complement: only the points close to a label change
            eval_idx = sorted(set(np.arange(end_idx[-1])) - set(eval_idx))
    else:
        eval_idx = idx_ToCompare

    TICC_metric = ClusterEvaluation(TICC_return, true_label, eval_idx, true_label_dict, verbo=verbo)
    pred2true_dict = TICC_metric[4]

    ## Overall clustering results
    true_lab_list = [true_label_dict[i] for i in sorted(set(pred2true_dict.values()))]
    if verbo:
        print('For overall clusters: ')
        print('Columns in the confusion matrix: ', true_lab_list)
        DisplayEvaluationResults(TICC_metric[2])
    lab_df = pd.DataFrame(TICC_metric[2][0], columns=true_lab_list, index=true_lab_list)

    ## Per-cluster results
    if verbo: print('\nFor each cluster: ')
    cluster_name = [true_label_dict[i] for i in list(pred2true_dict.values())]  # Name of each cluster
    cluster_idx = list(pred2true_dict.keys())  # Index of each cluster
    cluster_order_dict = {i: cluster_idx[i] for i in range(len(cluster_idx))}  # Cluster order -> cluster index

    # Initialize a dataframe to save the result for each cluster
    cluster_idx2name = [(cluster_order_dict[i], cluster_name[i]) for i in range(len(cluster_name))]
    cluster_df = pd.DataFrame(0, columns=true_lab_list, index=[i[0] for i in cluster_idx2name])
    cluster_df['clus_lab'] = [i[1] for i in cluster_idx2name]
    # Accuracy & size of each cluster; count_idx walks TICC_metric[3] in step with the ascending cluster indices
    cluster_acc, cluster_sum, count_idx = [], [], 0

    for mIdx in range(max(cluster_idx) + 1):
        if mIdx not in cluster_idx:
            continue
        gt_lab = true_label_dict[pred2true_dict[mIdx]]
        if verbo:
            print('Cluster {}:'.format(mIdx))
            print('Ground-truth label: {}'.format(gt_lab))
            print('Inside the cluster: ')
        # Column of the mapped label = number of points of each true label inside this cluster
        pred_counts = TICC_metric[3][count_idx][:, int(pred2true_dict[mIdx])]
        cluster_total = sum(pred_counts)
        pred_acc = []
        for tmpIdx in range(len(pred_counts)):
            if pred_counts[tmpIdx] != 0:
                tmp_acc = pred_counts[tmpIdx] / cluster_total
                pred_acc.append(tmp_acc)
                if verbo: print(
                    "- {}({:.4}%) are {}".format(pred_counts[tmpIdx], tmp_acc * 100, true_label_dict[tmpIdx]))
                cluster_df.loc[cluster_order_dict[count_idx], true_label_dict[tmpIdx]] = pred_counts[tmpIdx]
        cluster_acc.append(max(pred_acc))
        cluster_sum.append(cluster_total)
        if verbo: print('-------------------------------------------------------')
        count_idx += 1

    cluster_df['cluster_size'] = cluster_sum
    cluster_df['cluster_acc'] = cluster_acc

    return lab_df, cluster_df, cluster_idx2name, TICC_metric

# ----------------------------------------------------------------------------------------------------------
# Calculate the PageRank given the connection network M
def pagerank(M, num_iterations: int = 100, d: float = 0.85):
    """PageRank by power iteration.

    Parameters
    ----------
    M : numpy array
        Square adjacency matrix where M_i,j represents the link from 'j' to 'i'. Classic PageRank expects
        every column to sum to 1. Other non-negative matrices are accepted: the iteration then
        approximates the L1-normalised dominant eigenvector of ``d * M + (1 - d) / N``.
    num_iterations : int, optional
        Number of iterations, by default 100.
    d : float, optional
        Damping factor, by default 0.85.

    Returns
    -------
    numpy array of shape (N, 1)
        Rank vector normalised to unit L1 norm. For non-negative input it has entries in [0, 1] that
        sum to 1.
    """
    N = M.shape[1]
    if N == 0:
        return np.zeros((0, 1))
    # Deterministic uniform start: a random start made results differ between runs
    v = np.full((N, 1), 1.0 / N)
    M_hat = d * M + (1 - d) / N
    for _ in range(num_iterations):
        v = M_hat @ v
        # Renormalise every step. Positive rescaling does not change the direction of the iterate, but
        # it keeps the values finite when columns of M_hat sum to more than 1 (e.g. thresholded inverse
        # covariances); otherwise they overflow to inf/nan after many iterations.
        norm = np.linalg.norm(v, 1)
        if norm == 0:
            break
        v = v / norm
    return v


# Calculate the pagerank for each feature
def CalFeatPageRank(TICC_return, n_clusters, sel_feat, agg_pattern, threshold=1e-1):
    """Compute a PageRank score for every feature of every cluster's learned inverse covariance.

    ``TICC_return[4][(n_clusters, cluster_index)]`` is read for each cluster (a deep copy is modified)
    and viewed as a grid of ``n_feat x n_feat`` blocks, with ``n_feat = len(sel_feat)``. Entries below
    ``threshold`` are zeroed before ranking. ``agg_pattern`` selects how the blocks are combined:

    - ``'agg_incov'``: average all blocks into one ``n_feat x n_feat`` matrix, then rank it.
    - ``'agg_pagerank'``: rank the full matrix, then average the scores at each in-block position.
    - ``'single_incov'``: rank only the top-left block.

    Any other ``agg_pattern`` adds nothing, so the result is then an empty list.

    Returns
    -------
    list
        One entry of ``n_feat`` scores per cluster (a list, or a numpy array for ``'agg_pagerank'``).
    """
    n_feat = len(sel_feat)
    result = []

    for c_idx in range(n_clusters):
        inv_cov = copy.deepcopy(TICC_return[4][(n_clusters, c_idx)])
        blocks = int(np.shape(inv_cov)[0] / n_feat)

        if agg_pattern == 'agg_incov':
            # Block-wise mean over the (row block, column block) grid
            summed = np.zeros((n_feat, n_feat))
            for r in range(blocks):
                row_slice = slice(r * n_feat, (r + 1) * n_feat)
                for c in range(blocks):
                    summed += inv_cov[row_slice, c * n_feat:(c + 1) * n_feat]
            summed = summed / (blocks * blocks)
            summed[np.where(summed < threshold)] = 0
            ranks = pagerank(summed, d=0.85)
            result.append([list(row)[0] for row in ranks])

        elif agg_pattern == 'agg_pagerank':
            inv_cov[np.where(inv_cov < threshold)] = 0
            ranks = [list(row)[0] for row in pagerank(inv_cov, d=0.85)]
            per_block = [ranks[b * n_feat:(b + 1) * n_feat] for b in range(blocks)]
            result.append(np.mean(per_block, axis=0))

        elif agg_pattern == 'single_incov':
            top_left = inv_cov[:n_feat, :n_feat]
            top_left[np.where(top_left < threshold)] = 0
            result.append([list(row)[0] for row in pagerank(top_left, d=0.85)])

    return result


# Calculate the pagerank / betweenness centrality for each feature
def CalFeatMRFMetrics(TICC_return, n_clusters, sel_feat, agg_pattern, metric='pageRank', threshold=1e-1):
    """Score every feature of every cluster's learned inverse covariance by PageRank or betweenness.

    ``TICC_return[4][(n_clusters, cluster_index)]`` is deep-copied and viewed as a grid of
    ``n_feat x n_feat`` blocks (``n_feat = len(sel_feat)``). Entries below ``threshold`` are zeroed.

    ``agg_pattern`` selects how the blocks are combined:

    - ``'agg_incov'``: average all blocks into one matrix, then score it.
    - ``'agg_pagerank'``: score the full matrix, then average the scores at each in-block position.
    - ``'single_incov'``: score only the top-left block.

    ``metric`` is ``'pageRank'`` (``pagerank`` with d=0.85) or ``'betweenCentral'``
    (``networkx.betweenness_centrality`` on ``networkx.DiGraph(matrix)``, no ``weight`` argument passed).
    An unknown ``agg_pattern`` or ``metric`` adds nothing for that cluster.

    Returns
    -------
    list
        One entry of per-feature scores per cluster (a list, or a numpy array for ``'agg_pagerank'``).
    """
    n_feat = len(sel_feat)
    scores = []

    def _node_scores(mat):
        # Per-node scores of one matrix for the selected metric, or None for an unknown metric
        if metric == 'pageRank':
            return [list(row)[0] for row in pagerank(mat, d=0.85)]
        if metric == 'betweenCentral':
            graph = nx.DiGraph(mat)
            return list(nx.betweenness_centrality(graph).values())
        return None

    for c_idx in range(n_clusters):
        inv_cov = copy.deepcopy(TICC_return[4][(n_clusters, c_idx)])
        blocks = int(np.shape(inv_cov)[0] / n_feat)

        if agg_pattern == 'agg_incov':
            # Block-wise mean over the (row block, column block) grid
            incov = np.zeros((n_feat, n_feat))
            for r in range(blocks):
                row_slice = slice(r * n_feat, (r + 1) * n_feat)
                for c in range(blocks):
                    incov += inv_cov[row_slice, c * n_feat:(c + 1) * n_feat]
            incov = incov / (blocks * blocks)
            incov[np.where(incov < threshold)] = 0
            node_scores = _node_scores(incov)
            if node_scores is not None:
                scores.append(node_scores)

        elif agg_pattern == 'agg_pagerank':
            inv_cov[np.where(inv_cov < threshold)] = 0
            node_scores = _node_scores(inv_cov)
            if node_scores is not None:
                per_block = [node_scores[b * n_feat:(b + 1) * n_feat] for b in range(blocks)]
                scores.append(np.mean(per_block, axis=0))

        elif agg_pattern == 'single_incov':
            top_left = inv_cov[:n_feat, :n_feat]
            top_left[np.where(top_left < threshold)] = 0
            node_scores = _node_scores(top_left)
            if node_scores is not None:
                scores.append(node_scores)

    return scores


# Plot the MRF learned patterns
# MRF_metrics is a list of lists records the importance of each feature contributing the each cluster
def PlotMRFMetrics(TICC_metric, MRF_metrics, sel_feat, true_label_dict,
                   plot_cols=5, col_size=3, row_size=1.8, x_labels='feat_idx', x_rotate=False):
    """Plot each cluster's per-feature MRF importance curve in a grid of subplots.

    Parameters
    ----------
    TICC_metric : tuple
        Output of ``ClusterEvaluation``; ``TICC_metric[4]`` maps cluster -> ground-truth label and
        provides the subplot titles (via ``true_label_dict``).
    MRF_metrics : list of sequences
        ``MRF_metrics[i]`` holds one value per feature for the i-th cluster.
    sel_feat : list of str
        Feature names.
    true_label_dict : dict
        Ground-truth label -> display name.
    plot_cols : int
        Subplots per row.
    col_size, row_size : float
        Figure width per column and height per row (inches).
    x_labels : {'feat_idx', 'feat_name'}
        Tick labels 'f1', 'f2', ... or the feature names.
    x_rotate : bool
        Rotate tick labels by 90 degrees.

    Raises
    ------
    ValueError
        If ``x_labels`` is not 'feat_idx' or 'feat_name'.
    """
    if x_labels == 'feat_idx':
        plot_col_lab = ['f' + str(i + 1) for i in range(len(sel_feat))]
    elif x_labels == 'feat_name':
        plot_col_lab = sel_feat
    else:
        raise ValueError("x_labels must be 'feat_idx' or 'feat_name', got {!r}".format(x_labels))
    cluster_name = [true_label_dict[i] for i in TICC_metric[4].values()]

    # One subplot per cluster; the figure height scales with the number of subplot rows
    plot_rows = max(1, math.ceil(len(cluster_name) / plot_cols))
    plt.figure(figsize=(col_size * plot_cols, row_size * plot_rows))
    x = np.arange(len(sel_feat))
    for i, name in enumerate(cluster_name):
        plt.subplot(plot_rows, plot_cols, i + 1)
        if x_rotate:
            plt.xticks(x, plot_col_lab, rotation=90)
        else:
            plt.xticks(x, plot_col_lab)
        plt.plot(x, MRF_metrics[i])
        plt.title(name)

    plt.subplots_adjust(hspace=0.15 * plot_cols, wspace=0.4)
    plt.show()


# Plot the MRF learned patterns for each label
def PlotMRFMetricsPerLabel(TICC_metric, MRF_metrics, sel_feat, true_label_dict,
                           plot_cols=5, col_size=3, row_size=1.8, x_labels='feat_idx', x_rotate=False):
    """For each ground-truth label, plot the MRF importance curves of the clusters mapped to it.

    One figure is shown per label, in order of first appearance among the clusters.

    Parameters
    ----------
    TICC_metric : tuple
        Output of ``ClusterEvaluation``; ``TICC_metric[4]`` maps cluster -> ground-truth label.
    MRF_metrics : list of sequences
        ``MRF_metrics[i]`` holds one value per feature for the i-th cluster.
    sel_feat : list of str
        Feature names.
    true_label_dict : dict
        Ground-truth label -> display name.
    plot_cols : int
        Subplots per row.
    col_size, row_size : float
        Figure width per column and height per row (inches).
    x_labels : {'feat_idx', 'feat_name'}
        Tick labels 'f1', 'f2', ... or the feature names.
    x_rotate : bool
        Rotate tick labels by 90 degrees.

    Raises
    ------
    ValueError
        If ``x_labels`` is not 'feat_idx' or 'feat_name'.
    """
    if x_labels == 'feat_idx':
        plot_col_lab = ['f' + str(i + 1) for i in range(len(sel_feat))]
    elif x_labels == 'feat_name':
        plot_col_lab = sel_feat
    else:
        raise ValueError("x_labels must be 'feat_idx' or 'feat_name', got {!r}".format(x_labels))
    cluster_name = [true_label_dict[i] for i in TICC_metric[4].values()]
    x = np.arange(len(sel_feat))

    # dict.fromkeys keeps first-appearance order; iterating a set gave a run-dependent figure order
    for lab_name in dict.fromkeys(cluster_name):
        print('Label: ', lab_name)
        cluster_idx = [idx for idx, val in enumerate(cluster_name) if val == lab_name]
        plot_rows = math.ceil(len(cluster_idx) / plot_cols)
        plt.figure(figsize=(col_size * plot_cols, row_size * plot_rows))
        print('Cluster Index: {}'.format(cluster_idx))

        for pos, c_idx in enumerate(cluster_idx):
            plt.subplot(plot_rows, plot_cols, pos + 1)
            if x_rotate:
                plt.xticks(x, plot_col_lab, rotation=90)
            else:
                plt.xticks(x, plot_col_lab)
            plt.plot(x, MRF_metrics[c_idx])
            plt.title(cluster_name[c_idx])

        plt.subplots_adjust(hspace=0.15 * plot_cols, wspace=0.4)
        plt.show()


# Given the window size, lambda, and beta, Plot the elbow with different cluster numbers
def PlotClusterNumElbow(cn_metric_list, clusters_num_list):
    """Plot BIC and a second criterion (labelled 'ACC') against the number of clusters, side by side.

    Parameters
    ----------
    cn_metric_list : list
        For the i-th entry of ``clusters_num_list``: ``cn_metric_list[i][0]`` is the BIC and
        ``cn_metric_list[i][1][0]`` is the value plotted as 'ACC'.
    clusters_num_list : list
        Cluster numbers, used as x tick labels.

    Returns
    -------
    (BIC_list, ACC_list)
    """
    selIdxToPlot = list(range(len(clusters_num_list)))
    BIC_list = [cn_metric_list[i][0] for i in selIdxToPlot]
    ACC_list = [cn_metric_list[i][1][0] for i in selIdxToPlot]
    tick_labels = [clusters_num_list[i] for i in selIdxToPlot]

    plt.figure(figsize=(12, 4))

    # BIC
    plt.subplot(1, 2, 1)
    plt.plot(selIdxToPlot, BIC_list, marker='o')
    plt.xticks(selIdxToPlot, tick_labels)
    plt.grid(which='major', axis='y', linestyle='--')
    plt.xlabel('#Cluster')
    plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    plt.ylabel('BIC')

    # ACC
    plt.subplot(1, 2, 2)
    plt.plot(selIdxToPlot, ACC_list, marker='o')
    plt.xticks(selIdxToPlot, tick_labels)
    plt.grid(which='major', axis='y', linestyle='--')
    plt.xlabel('#Cluster')
    plt.ylabel('ACC')
    plt.subplots_adjust(hspace=0.15, wspace=0.4)
    plt.show()

    return BIC_list, ACC_list


# Given the window size, lambda, and beta, Plot the elbow with different cluster numbers for different metrics
def PlotAllClusterNumElbow(cn_metric_list, clusters_num_list, metric_list, xticks_unit=1):
    """Plot every metric in ``metric_list`` against the number of clusters, one subplot per metric.

    Parameters
    ----------
    cn_metric_list : list of sequences
        ``cn_metric_list[i][k]`` is metric ``k`` for the i-th entry of ``clusters_num_list``.
    clusters_num_list : list
        Cluster numbers, used as x tick labels.
    metric_list : list of str
        Metric names; metric ``k`` is drawn in subplot ``k`` with this name as y label.
    xticks_unit : int
        Label only every ``xticks_unit``-th cluster number.
    """
    selIdxToPlot = list(range(len(clusters_num_list)))
    n_plots = len(metric_list)

    # Width follows the number of subplots actually drawn
    plt.figure(figsize=(2.5 * n_plots, 2))

    for idx, metric_name in enumerate(metric_list):
        plt.subplot(1, n_plots, idx + 1)
        tmp_metric = [cn_metric_list[i][idx] for i in selIdxToPlot]
        plt.plot(selIdxToPlot, tmp_metric, marker='o')
        if xticks_unit == 1:
            plt.xticks(selIdxToPlot, [clusters_num_list[i] for i in selIdxToPlot])
        else:
            plt.xticks(list(range(0, len(clusters_num_list), xticks_unit)),
                       [clusters_num_list[i] for i in selIdxToPlot if i % xticks_unit == 0])

        plt.grid(which='major', axis='y', linestyle='--')
        plt.xlabel('#Cluster')
        # Large-magnitude metrics (e.g. BIC) use scientific notation on the y axis
        if max(abs(v) for v in tmp_metric) > 100:
            plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
        plt.ylabel(metric_name)

    plt.subplots_adjust(hspace=0.15, wspace=0.4)
    plt.show()

# ----------------------------------------------------------------------------------------------------------
# Check the parameters based on grid search results
def GridSearchLearnedPara(dataset, input_pattern, interval_pattern, window_pattern, dynamic_pattern, dynamic_attention):
    """Pick the best grid-search run and parse its hyper-parameters from the result file name.

    Result files are read from ``../Data/<dataset>/gridSearch/results/<input_pattern>/<interval_pattern>/
    <window_pattern>/<dynamic_pattern>/<dynamic_attention>/``, relative to the current working directory.
    Each file is a pickle in which ``[0]`` is the BIC, ``[1][0]`` the purity and ``[2][1][3]`` the
    F-score. The selected run is the one picked by most of the three criteria (lowest BIC, highest
    purity, highest F-score). If all three disagree, the lowest-BIC run is used.

    Returns
    -------
    para_dict : dict
        'fws', 'nc', 'ld', 'bt' as floats and, when ``window_pattern == 'dynamic'``, 'dw' as a list of
        floats, parsed from the ``<name>-<value><end>`` fields of the selected file name.
    [BIC_list, purity_list, f_score_list] : list of lists
        Metrics aligned with ``gs_files``.
    gs_files : list of str
        Result file names, sorted for a reproducible order.

    Raises
    ------
    ValueError
        If the folder contains no result files.
    """
    gs_folder = ('../Data/' + dataset + '/gridSearch/results/' + input_pattern + '/' +
                 interval_pattern + '/' + window_pattern + '/' + dynamic_pattern + '/' + dynamic_attention + '/')

    # os.listdir returns entries in arbitrary order; sort so the results (and argmin/argmax ties) are
    # reproducible. Hidden entries such as .DS_Store and sub-directories are not result files.
    gs_files = sorted(name for name in os.listdir(gs_folder)
                      if not name.startswith('.') and os.path.isfile(os.path.join(gs_folder, name)))
    if not gs_files:
        raise ValueError('No grid search result files found in ' + gs_folder)

    # Load the metrics of every run
    # NOTE: pickle.load can execute arbitrary code; only use this on trusted result folders.
    BIC_list, purity_list, f_score_list = [], [], []
    for tmp_file in gs_files:
        with open(os.path.join(gs_folder, tmp_file), 'rb') as filehandle:
            load_file = pickle.load(filehandle)
        BIC_list.append(load_file[0])
        purity_list.append(load_file[1][0])
        f_score_list.append(load_file[2][1][3])

    # Majority vote of the three criteria; max() over the list returns the first most-common entry,
    # so a three-way disagreement falls back to the BIC choice
    sel_idx_list = [int(np.argmin(BIC_list)), int(np.argmax(purity_list)), int(np.argmax(f_score_list))]
    sel_idx = max(sel_idx_list, key=sel_idx_list.count)

    # List of all parameters
    if window_pattern == 'dynamic':
        para_list = ['fws', 'nc', 'ld', 'bt', 'dw']
    else:
        para_list = ['fws', 'nc', 'ld', 'bt']
    # Character that terminates each '<name>-<value>' field in the file name
    # NOTE(review): 'bt' ends at the first '.', so a fractional beta (e.g. 'bt-0.5.data') would be cut at
    # its decimal point; confirm beta values in file names are always integers.
    end_sign = {'fws': '_', 'nc': '_', 'ld': '_', 'bt': '.', 'dw': '_'}

    # Extract the parameters from the selected file name
    sel_name = gs_files[sel_idx]
    para_dict = {}
    for tmp_para in para_list:
        start_idx = sel_name.index(tmp_para + '-') + len(tmp_para) + 1
        end_idx = sel_name.index(end_sign[tmp_para], start_idx)
        field = sel_name[start_idx:end_idx]
        if tmp_para == 'dw':
            para_dict[tmp_para] = [float(i) for i in re.split('-', field)]
        else:
            para_dict[tmp_para] = float(field)

    return para_dict, [BIC_list, purity_list, f_score_list], gs_files



# ----------------------------------------------------------------------------------------------------------

def corrected_dependent_ttest(data1, data2, n_training_folds, n_test_folds, alpha):
    """Corrected resampled paired t-test (Nadeau & Bengio) on two lists of per-fold scores.

    The variance of the mean difference is inflated by ``n_test_folds / n_training_folds`` to account
    for overlapping training sets.

    Parameters
    ----------
    data1, data2 : sequence of float
        Paired scores; ``len(data1)`` (at least 2) pairs are used.
    n_training_folds, n_test_folds : int
        Sizes entering the correction ratio.
    alpha : float
        Significance level for the critical value.

    Returns
    -------
    (t_stat, df, cv, p)
        ``cv`` is the one-sided critical value ``t.ppf(1 - alpha, df)``; ``p`` is the two-sided p-value.
        If every difference is identical (zero variance), ``t_stat`` is 0 when they are all zero
        (p = 1) and +/-inf otherwise (p = 0).
    """
    n = len(data1)
    differences = [(data1[i] - data2[i]) for i in range(n)]
    sd = stdev(differences)
    mean_diff = 1 / n * sum(differences)
    test_training_ratio = n_test_folds / n_training_folds
    denominator = sqrt(1 / n + test_training_ratio) * sd
    if denominator == 0:
        # Zero variance: no evidence of a difference if all differences are zero, else unbounded
        t_stat = 0.0 if mean_diff == 0 else math.copysign(math.inf, mean_diff)
    else:
        t_stat = mean_diff / denominator
    # degrees of freedom
    df = n - 1
    # one-sided critical value
    cv = t.ppf(1.0 - alpha, df)
    # two-sided p-value
    p = (1.0 - t.cdf(abs(t_stat), df)) * 2.0
    return t_stat, df, cv, p


def CheckDifferentModelbyTtest(comp_path_dict, comp_set_dict, model_ind, sel_model, base_path,
                               fold_num=2, ttest_repeat=5, metric_num=4, eval_method='ttest', cv_ttest=True):
    """Test whether two models' stored cross-validation metrics differ significantly.

    For each of the six settings (rows) and each of the first ``metric_num`` metrics (columns), the
    per-fold scores of model side 'A' and side 'B' are loaded and compared.

    Parameters
    ----------
    comp_path_dict : dict
        Model name -> file-name postfix. Scores are read from ``<prefix>fold_<k><postfix>metrics.data``,
        where ``<prefix>`` is ``base_path + '<repeat>/'`` if ``cv_ttest`` else ``base_path + '/'``. Each
        file is a pickle indexed ``[setting][repeat][metric]``; repeat 0 is used.
    comp_set_dict : dict
        Model name -> 'A' or 'B', the side of the comparison the model belongs to.
    model_ind : dict
        ``model_ind[sel_model]`` lists the model names to load.
    sel_model : hashable
        Key into ``model_ind``.
    base_path : str
        Root path of the stored metrics.
    fold_num : int
        Number of CV folds stored per repeat.
    ttest_repeat : int
        Number of CV repeats (forced to 1 for 'ttest').
    metric_num : int
        Number of leading metrics to test (at most 9).
    eval_method : {'ttest', 'correct_ttest', 'multi_ttest'}
        - 'ttest': ``scipy.stats.ttest_rel`` on the pooled fold scores.
        - 'correct_ttest': ``corrected_dependent_ttest`` on the pooled fold scores.
        - 'multi_ttest': Dietterich's 5x2cv-style paired t-test, with ``ttest_repeat`` degrees of freedom.
    cv_ttest : bool
        Whether results are stored in one sub-folder per repeat.

    Returns
    -------
    pandas.DataFrame
        Rows 'train', 'out_train', 'in_train', 'test', 'out_test', 'in_test'; one column per metric. Each
        cell is the p-value as text, suffixed '*' if < 0.05 and '**' if < 0.01.

    Raises
    ------
    ValueError
        If ``eval_method`` is not supported.
    """
    if eval_method not in ('ttest', 'correct_ttest', 'multi_ttest'):
        raise ValueError("eval_method must be 'ttest', 'correct_ttest' or 'multi_ttest', got {!r}"
                         .format(eval_method))
    if eval_method == 'ttest':
        # The plain paired t-test only uses the first repeat
        ttest_repeat = 1

    def _FormatP(p_val):
        # str(p)[:6] keeps four decimals of an ordinary p-value but would cut the exponent off a tiny one
        # (1.2345e-07 -> '1.2345'), so values printed in scientific notation are formatted explicitly
        text = str(p_val)
        text = '{:.2e}'.format(p_val) if 'e' in text else text[:6]
        if p_val < 0.01:
            return text + '**'
        if p_val < 0.05:
            return text + '*'
        return text

    all_tt = []
    # Settings 0-2: training (all / outside / inside the tolerance window); 3-5: the same for testing
    for set_idx in range(6):
        set_tt = []
        for metric_idx in range(metric_num):  # 0: accuracy, 1: recall, ...
            # Scores pooled over repeats and folds; reset for every metric so metrics are not mixed
            a1, a2 = [], []
            s_sqr_sum = []
            p1_1 = None
            for rt_idx in range(ttest_repeat):
                side_metrics = {'A': [], 'B': []}
                prefix = (base_path + str(rt_idx) + '/') if cv_ttest else (base_path + '/')
                for tmp_model in model_ind[sel_model]:
                    postfix = comp_path_dict[tmp_model]
                    for cvIdx in range(fold_num):
                        tmp_path = prefix + 'fold_' + str(cvIdx) + postfix + 'metrics.data'
                        # NOTE: pickle.load can execute code; only load trusted result files
                        with open(tmp_path, 'rb') as filehandle:
                            tmp_metrics = pickle.load(filehandle)
                        # Index: [setting][repeat 0][metric]
                        side_metrics[comp_set_dict[tmp_model]].append(tmp_metrics[set_idx][0][metric_idx])
                A_metrics, B_metrics = side_metrics['A'], side_metrics['B']

                # Per-fold differences and their squared deviation (5x2cv components)
                p = [A_metrics[i] - B_metrics[i] for i in range(len(A_metrics))]
                p_mean = np.mean(p)
                s_sqr_sum.append(np.sum([(i - p_mean) ** 2 for i in p]))
                if rt_idx == 0:
                    # Dietterich's statistic uses the difference on the first fold of the first repeat
                    p1_1 = p[0]
                a1.extend(A_metrics)
                a2.extend(B_metrics)

            if eval_method == 'ttest':
                tmp_p = stats.ttest_rel(a1, a2)[1]
            elif eval_method == 'correct_ttest':
                # The training/testing split is a property of the setting, not of the metric
                test_num = fold_num - 1 if set_idx < 3 else fold_num
                tmp_p = corrected_dependent_ttest(a1, a2, fold_num - 1, test_num, alpha=0.05)[-1]
            else:  # 'multi_ttest'
                tt = p1_1 / np.sqrt(np.mean(s_sqr_sum))
                # Dietterich's statistic is t-distributed with (number of repeats) degrees of freedom
                tmp_p = stats.t.sf(np.abs(tt), ttest_repeat) * 2

            set_tt.append(_FormatP(tmp_p))

        all_tt.append(set_tt)

    metric_cols = ['accuracy', 'recall', 'precision', 'fscore', 'NMI', 'ARI',
                   'h_score', 'c_score', 'v_measure']
    ttest_df = pd.DataFrame(np.array(all_tt), columns=metric_cols[:metric_num],
                            index=['train', 'out_train', 'in_train', 'test', 'out_test', 'in_test'])
    return ttest_df


# Generate the contingency table based on loaded labels from files
def GenerateContigencyTable(comp_path_dict, comp_set_dict, model_ind, sel_model, sel_data, prefix,
                            fold_check=1):
    """Build the 2x2 McNemar contingency table of two models' per-sample correctness.

    Predictions and ground truth are read from ``<prefix>fold_<k><postfix>pred_label.data`` and
    ``...true_label.data`` (pickles) for folds ``0 .. fold_check - 1``. A sample is correct when its
    predicted label equals its true label.

    Parameters
    ----------
    comp_path_dict : dict
        Model name -> file-name postfix.
    comp_set_dict : dict
        Model name -> 'A' or 'B', the side of the comparison the model belongs to.
    model_ind : dict
        ``model_ind[sel_model]`` lists the model names to load.
    sel_model : hashable
        Key into ``model_ind``.
    sel_data : str
        Unused; kept for interface compatibility.
    prefix : str
        Path prefix of the stored label files.
    fold_check : int
        Number of folds to include.

    Returns
    -------
    list
        ``[[both correct, A correct & B wrong], [A wrong & B correct, both wrong]]`` (also printed).
    """
    # Samples are keyed by (fold, position) so that equal positions in different folds stay distinct
    correct = {'A': set(), 'B': set()}
    incorrect = {'A': set(), 'B': set()}

    for tmp_model in model_ind[sel_model]:
        postfix = comp_path_dict[tmp_model]
        side = comp_set_dict[tmp_model]

        for cvIdx in range(fold_check):  # fold_check: 1/fold_num
            # NOTE: pickle.load can execute code; only load trusted result files
            with open(prefix + 'fold_' + str(cvIdx) + postfix + 'pred_label.data', 'rb') as filehandle:
                tmp_predlabs = pickle.load(filehandle)

            with open(prefix + 'fold_' + str(cvIdx) + postfix + 'true_label.data', 'rb') as filehandle:
                tmp_truelabs = pickle.load(filehandle)

            for idx, (pred, truth) in enumerate(zip(tmp_predlabs, tmp_truelabs)):
                target = correct if pred == truth else incorrect
                target[side].add((cvIdx, idx))

    # ------------------------------------------------------------------------
    # Generate the contingency table
    x1 = len(correct['A'] & correct['B'])
    x2 = len(correct['A'] & incorrect['B'])
    x3 = len(incorrect['A'] & correct['B'])
    x4 = len(incorrect['A'] & incorrect['B'])
    table = [[x1, x2], [x3, x4]]
    print('Contingency table: ', table)
    return table


# Measure the difference of two models by mcnemar tests
# Input: contingency table
def CalMcnemarPvalue(table):
    """Run McNemar's test on a 2x2 contingency table and print the verdict at alpha = 0.05.

    Uses statsmodels' ``mcnemar`` with ``exact=False`` (chi-square) and ``correction=True``
    (continuity correction).

    Parameters
    ----------
    table : 2x2 array-like
        Contingency table of the two models' correct/incorrect predictions.

    Returns
    -------
    The statsmodels result object (``.statistic``, ``.pvalue``).
    """
    alpha = 0.05  # significance level
    result = mcnemar(table, exact=False, correction=True)
    print('statistic=%.3f, p-value=%.6f' % (result.statistic, result.pvalue))

    if result.pvalue > alpha:
        verdict = 'Same proportions of errors (fail to reject H0)'
    else:
        verdict = 'Different proportions of errors (reject H0)'
    print(verdict)
    return result


# Prepare the data to conduct the t-test
def TtestDataPreparation(model_ind, cvar_ind, sel_model, sel_data, comp_list, metric_parafix, base_path,
                         tt_measure='mcnemar', cv_ttest=True):
    """Build the file-name postfixes, side mapping and path prefix used to compare two models.

    Parameters
    ----------
    model_ind, cvar_ind : dict
        ``model_ind[sel_model]`` / ``cvar_ind[sel_model]`` list the two model names and their variant tags.
    sel_model : {'input', 'input_time', 'time'}
        Controls the order of model name and variant tag in the postfix.
    sel_data : str
        Dataset name (used for 'mcnemar').
    comp_list : list
        Side labels (e.g. ['A', 'B']) assigned to the two models.
    metric_parafix : str
        Postfix head used for the t-test variants.
    base_path : str
        Root path of the stored results.
    tt_measure : str
        'mcnemar', or any string containing 'ttest'.
    cv_ttest : bool
        If True the prefix points to the sub-folder of repeat 0.

    Returns
    -------
    (comp_path_dict, comp_set_dict, tmp_prefix)
        model name -> postfix, model name -> side label, and the path prefix.

    Raises
    ------
    ValueError
        For an unsupported ``tt_measure`` or ``sel_model``.
    """
    models = model_ind[sel_model]
    # Dictionary mapping each of the two models to its comparison side (A or B)
    comp_set_dict = {models[i]: comp_list[i] for i in range(2)}

    # Head and tail of the file-name postfix
    if tt_measure == 'mcnemar':
        file_pre = '/' + sel_data + '_'
        file_post = 'rp-' + str(0) + '_'
    elif 'ttest' in tt_measure:
        file_pre = metric_parafix + '_'
        file_post = ''
    else:
        raise ValueError("tt_measure must be 'mcnemar' or contain 'ttest', got {!r}".format(tt_measure))

    if sel_model in ('input', 'input_time'):
        name_parts = [(models[i], cvar_ind[sel_model][i]) for i in range(2)]
    elif sel_model == 'time':
        name_parts = [(cvar_ind[sel_model][i], models[i]) for i in range(2)]
    else:
        raise ValueError("sel_model must be 'input', 'input_time' or 'time', got {!r}".format(sel_model))
    tmp_postfix = [file_pre + first + '_' + second + '_' + file_post for first, second in name_parts]

    comp_path_dict = {models[i]: tmp_postfix[i] for i in range(len(comp_list))}

    if cv_ttest:
        tmp_prefix = base_path + str(0) + '/'
    else:
        tmp_prefix = base_path + '/'

    return comp_path_dict, comp_set_dict, tmp_prefix