import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns

def random_support(m, n, ratio):
    """Return an (m, n) 0/1 matrix with ``int(m * n * ratio)`` ones at random positions.

    The ones are placed at the entries holding the largest values of an
    i.i.d. standard-normal score matrix. Because the scores are exchangeable,
    every subset of the requested size is equally likely.

    Args:
        m: Number of rows.
        n: Number of columns.
        ratio: Fraction of entries to set to 1. The resulting count is
            truncated with ``int`` and clamped to ``[0, m * n]``.

    Returns:
        A float ``np.ndarray`` of shape ``(m, n)`` containing only 0.0 and 1.0.
    """
    total = m * n
    # Clamp so that a count of 0 yields an empty support. A negative-index
    # slice ``[-k:]`` would select *every* entry when k == 0.
    supp_size = min(max(int(total * ratio), 0), total)
    score_matrix = np.random.randn(m, n)
    support = np.zeros((m, n))
    # argsort with axis=None sorts the C-order flattened matrix, so the last
    # ``supp_size`` indices are the positions of the largest scores.
    order = np.argsort(score_matrix, axis=None)[total - supp_size:]
    # Flat (C-order) indices map directly onto the flattened view of support.
    support.flat[order] = 1
    return support

def sufficient_nonclosedness(support1, support2):
    """Check a pairwise column condition on the entries each rank-one term covers alone.

    For every inner index ``i``, the rank-one pattern
    ``outer(support1[:, i], support2[i, :])`` is compared with the product
    ``support1 @ support2``. Entries of that pattern that no other inner index
    covers form ``unique_support``, a 0/1 matrix. The check passes for ``i``
    when every two columns of ``unique_support`` are either identical or
    disjoint (share no 1 entry).

    NOTE(review): the name suggests this is a sufficient condition for some
    "non-closedness" property of the support pair; confirm against the source
    of the experiment.

    Args:
        support1: ``(m, r)`` array, presumably a 0/1 support pattern.
        support2: ``(r, n)`` array, presumably a 0/1 support pattern.

    Returns:
        True if the condition holds for every inner index, False otherwise.

    Raises:
        ValueError: if ``support1.shape[1] != support2.shape[0]``.
    """
    if support1.shape[1] != support2.shape[0]:
        raise ValueError(
            f"support1 has {support1.shape[1]} columns but support2 has "
            f"{support2.shape[0]} rows"
        )
    r = support1.shape[1]
    # product[a, b] = how many inner indices cover entry (a, b).
    product = support1 @ support2
    for i in range(r):
        rank_one_support = np.outer(support1[:, i], support2[i, :])
        # 1 where no index other than i covers the entry (covered by i alone,
        # or not covered at all).
        unique_and_outside_support = np.where((product - rank_one_support) == 0, 1, 0)
        # Restrict to the entries that index i actually covers.
        unique_support = np.where((unique_and_outside_support + rank_one_support) == 2, 1, 0)

        # overlap[j, k] = number of rows where columns j and k are both 1.
        # Counts are at most m, so float64 is exact and lets matmul use BLAS.
        cols = unique_support.astype(np.float64)
        overlap = cols.T @ cols
        sizes = np.diag(overlap)
        # For 0/1 columns: identical <=> overlap == |col_j| == |col_k|, and
        # disjoint <=> overlap == 0. All-zero columns only ever overlap by 0,
        # and the diagonal is always "identical", so neither needs filtering.
        identical = (overlap == sizes[:, None]) & (overlap == sizes[None, :])
        if np.any((overlap > 0) & ~identical):
            return False
    return True

if __name__ == "__main__":
    # Support densities tried for both factors.
    pool = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    n_samples = 100
    size = 100

    # result[i, j] = fraction of samples where sufficient_nonclosedness fails
    # for support1 of density pool[i] and support2 of density pool[j].
    result = np.zeros((len(pool), len(pool)))
    for i, ratio1 in enumerate(pool):
        for j, ratio2 in enumerate(pool):
            failures = 0
            for _ in range(n_samples):
                support1 = random_support(size, size, ratio1)
                support2 = random_support(size, size, ratio2)
                if not sufficient_nonclosedness(support1, support2):
                    failures += 1
            result[i, j] = failures / n_samples

    # Passing the labels to heatmap centres the ticks on each cell.
    ax = sns.heatmap(result, annot=True, linewidths=.5,
                     xticklabels=pool, yticklabels=pool)
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")
    ax.set_xlabel("support2 density")
    ax.set_ylabel("support1 density")
    plt.savefig("proba.png", dpi=200)
    plt.close(ax.figure)
    
    