import itertools
import os
import random

import networkx as nx
import numpy as np
import pandas as pd
from networkx.algorithms import bipartite
from numpy.linalg import eig

import syntheticData

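# Evaluate the query complexity (number of distinct graph edges explored by random walks) and accuracy of the spectral clustering oracle.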

def defineRandomList(random_list):
    """Publish ``random_list`` as the module-level name ``my_random_list``.

    ``toClusteringSets`` reads that global to map a position in a label list
    back to the node that sits at that position.
    """
    globals()["my_random_list"] = random_list

def toClusteringSets(inCluster, clusterNum=None, nodes=None):
    """Group nodes into sets according to their integer cluster labels.

    Args:
        inCluster: sequence of non-negative integer labels; ``inCluster[u]``
            is the cluster of ``nodes[u]``.
        clusterNum: number of clusters to return. Defaults to
            ``max(inCluster) + 1``. Pass it explicitly so that clusters which
            received no node still appear (as empty sets).
        nodes: node identifiers parallel to ``inCluster``. Defaults to the
            module-level ``my_random_list`` set by ``defineRandomList``.

    Returns:
        A list of ``clusterNum`` sets; entry ``i`` holds the nodes labelled
        ``i``. An empty ``inCluster`` yields ``[]`` (or ``clusterNum`` empty
        sets when ``clusterNum`` is given).
    """
    if nodes is None:
        nodes = my_random_list
    if clusterNum is None:
        # default=-1 makes an empty label list produce zero clusters
        clusterNum = max(inCluster, default=-1) + 1

    ClusteringSets = [set() for _ in range(clusterNum)]
    for u, label in enumerate(inCluster):
        ClusteringSets[label].add(nodes[u])
    return ClusteringSets

def getMatching(clustering, plantedClusters):
    """Score a clustering against the planted (ground-truth) clusters.

    Each planted cluster is paired with at most one output cluster (and vice
    versa) so that the total number of shared nodes is maximised; this is a
    maximum-weight bipartite assignment on the overlap matrix.

    Args:
        clustering: list of iterables of nodes produced by the algorithm.
        plantedClusters: list of iterables of nodes (the ground truth).

    Returns:
        The number of nodes that lie in the planted cluster matched to their
        output cluster, divided by the total size of ``clustering``. When
        there are fewer output clusters than planted ones, the unmatched
        planted clusters contribute 0. Returns 0.0 if ``clustering`` holds
        no nodes or ``plantedClusters`` is empty.

    Also prints the matching and a per-planted-cluster accuracy line.
    """
    from scipy.optimize import linear_sum_assignment

    plantedSets = [set(A) for A in plantedClusters]
    foundSets = [set(B) for B in clustering]
    N = sum(len(x) for x in clustering)
    if N == 0 or not plantedSets:
        return 0.0

    # overlap[i, j] = number of nodes shared by planted cluster i and output cluster j
    overlap = np.array([[len(A & B) for B in foundSets] for A in plantedSets])

    # For a rectangular matrix only min(rows, cols) pairs get matched; a
    # full matching of the planted side is therefore not required.
    rows, cols = linear_sum_assignment(overlap, maximize=True)
    matchedTo = dict(zip(rows.tolist(), cols.tolist()))
    print({f'l{i}': f'r{j}' for i, j in matchedTo.items()})

    value = 0
    for i, A in enumerate(plantedSets):
        hits = int(overlap[i, matchedTo[i]]) if i in matchedTo else 0
        value += hits
        clusterAccuracy = hits / len(A) if A else 0.0
        print(f'\tgot accuracy {clusterAccuracy} for cluster with {len(A)} vertices')
    return value / N



class dotProductOracle:
    """Estimate dot products between vertices from lazy random walks.

    On construction, ``s`` vertices are sampled uniformly with replacement,
    and ``R_init`` lazy walks of length ``t`` are run from each. These walks
    estimate a transition matrix ``Q`` and a collision matrix ``G``, and the
    matrix ``psi`` is built from the top-``k`` eigenpairs of ``(N / s) * G``.
    ``SpectralDotProductOracle(x, y)`` combines fresh walks from ``x`` and
    ``y`` with ``psi`` and ``Q``.

    Every edge traversed by any walk is added, as a sorted pair, to the
    module-level ``global_edge_set``, which must exist before the oracle is
    built. The nodes of ``G`` are assumed to be the integers ``0 .. N-1``
    (neighbour lists are built by index).
    """

    def __init__(self, G, k, R_init, R_query, t, s):
        self.G = G
        self.k = k              # number of eigenpairs kept in psi
        self.R_init = R_init    # walks per sampled vertex during initialisation
        self.R_query = R_query  # walks per vertex for each dot-product query
        self.t = t              # walk length
        self.s = s              # number of sampled vertices
        self.N = len(G.nodes())

        self.degrees = dict(G.degree())
        # default=0 keeps a node-less graph from raising here
        self.max_degree = max(self.degrees.values(), default=0)
        self.neighbor = [list(self.G.neighbors(i)) for i in range(self.N)]
        self.structure = self.InitializeOracle()

    def RunRandomWalks(self, R, t, x):  # lazy random walk
        """Run ``R`` lazy walks of ``t`` steps from ``x``; return the end-point distribution.

        At each step the walk moves to a uniformly random neighbour with
        probability ``deg / (2 * max_degree)`` and otherwise stays put.

        Returns:
            A list of length ``N`` whose entry ``i`` is the fraction of walks
            that ended at vertex ``i``.
        """
        counts = [0] * self.N
        scale = 2 * self.max_degree
        for _ in range(R):
            curr_vtx = x
            for _ in range(t):
                deg = self.degrees[curr_vtx]
                # With no edges at all (scale == 0) every walk stays at x.
                if scale and random.random() < deg / scale:
                    u = curr_vtx
                    curr_vtx = random.choice(self.neighbor[u])
                    global_edge_set.add((u, curr_vtx) if u <= curr_vtx else (curr_vtx, u))
            counts[curr_vtx] += 1

        # get probability distribution vector
        return [c / R for c in counts]

    def EstimateTransitionMatrix(self, S, R, t):
        """Return the N x len(S) matrix whose column i is the walk distribution from S[i]."""
        return np.array([self.RunRandomWalks(R, t, x) for x in S]).T

    def EstimateCollisionMatrix(self, S, R, t):
        """Estimate the collision matrix of S from two independent transition estimates."""
        Q = self.EstimateTransitionMatrix(S, R, t)
        P = self.EstimateTransitionMatrix(S, R, t)
        return (np.dot(P.T, Q) + np.dot(Q.T, P)) / 2

    @staticmethod
    def _psiFromCollision(G, k, scale):
        """Return ``scale * W_k diag(1 / lambda_k**2) W_k^T`` for the top-k eigenpairs of G.

        G is symmetric by construction, so ``eigh`` yields real eigenpairs.
        Its ascending output is re-ordered largest-first so that the *top*
        k eigenpairs are used. ``numpy.linalg.eig`` returns eigenvalues in
        no particular order, so slicing its first k columns picks arbitrary
        eigenvectors.

        Returns None if any of those k eigenvalues is exactly zero, since
        they are inverted.
        """
        vals, vecs = np.linalg.eigh(G)
        order = np.argsort(vals)[::-1][:k]
        top_vals = vals[order]
        top_vecs = vecs[:, order]
        if np.any(top_vals == 0):
            return None
        # Scaling column j of W_k by 1 / lambda_j**2 forms W_k Lambda_k^-2.
        return scale * np.dot(top_vecs * (1.0 / top_vals ** 2), top_vecs.T)

    def InitializeOracle(self):
        """Sample vertices and build ``[psi, Q]``; return None on a zero top-k eigenvalue.

        On success, also appends the current size of ``global_edge_set`` to
        ``./Results/queryComplexity.txt``.
        """
        print("initiate dorProductOracle")
        # random sampling: s vertices, uniformly with replacement
        I_S = [int((self.N * random.random()) % self.N) for _ in range(self.s)]

        Q = self.EstimateTransitionMatrix(I_S, self.R_init, self.t)
        scale = self.N / self.s
        G = self.EstimateCollisionMatrix(I_S, self.R_init, self.t) * scale
        psi = self._psiFromCollision(G, self.k, scale)
        if psi is None:
            return None

        with open("./Results/queryComplexity.txt", "a") as file:
            file.write("Num of edges (initiate dotProductOracle):  " + str(len(global_edge_set)) + "\n")
        return [psi, Q]

    def SpectralDotProductOracle(self, x, y):
        """Return ``alpha_x^T psi alpha_y``, where ``alpha_v = Q^T m_v``.

        ``m_v`` is the end-point distribution of ``R_query`` fresh walks from ``v``.

        Raises:
            RuntimeError: if initialisation failed (``structure`` is None).
        """
        if self.structure is None:
            raise RuntimeError("dotProductOracle is not initialised: "
                               "a top-k eigenvalue of the collision matrix was zero")
        psi, Q = self.structure
        m_x = self.RunRandomWalks(self.R_query, self.t, x)
        m_y = self.RunRandomWalks(self.R_query, self.t, y)

        alpha_x = np.dot(Q.T, m_x)
        alpha_y = np.dot(Q.T, m_y)
        return np.dot(np.dot(alpha_x.T, psi), alpha_y)



class ClusteringOracle:
    """Answer "which cluster is vertex x in?" using a dotProductOracle.

    Construction samples ``s`` vertices and joins two of them in the
    similarity graph ``H`` when their estimated dot product is at least
    ``theta``. It retries until ``H`` has exactly ``k`` connected components.
    A query vertex is assigned to the first component all of whose vertices
    it is similar to.

    Uses the module-level ``global_edge_set`` to count explored edges.
    """

    def __init__(self, G, k, R_init, R_query, t, s, s_dot, theta):
        self.G = G
        self.k = k
        self.R_init = R_init
        self.R_query = R_query
        self.t = t
        self.s = s          # number of vertices sampled into H
        self.s_dot = s_dot  # sample size handed to the dot-product oracle
        self.theta = theta  # dot-product threshold for an edge of H
        self.dotOracle = dotProductOracle(G, k, R_init, R_query, t, s_dot)

        # Count only edges explored from here on (building H and answering
        # queries), not those used to initialise the dot-product oracle.
        global_edge_set.clear()
        # NOTE: retries indefinitely if H never ends up with k components.
        self.H = self.constructOracle()
        while self.H == "fail":
            # constructOracle has already cleared global_edge_set on failure.
            print("Construct oracle fail!!!")
            self.H = self.constructOracle()

        # H does not change after construction, so find its components once
        # here instead of on every query.
        self.components = list(nx.connected_components(self.H))

        # write txt
        with open("Results/queryComplexity.txt", "a") as file:
            file.write("\nclear global_edge_set\n")
            file.write("Num of edges (initiate clusteringOracle):  " + str(len(global_edge_set)) + "\n")

    def constructOracle(self):
        """Build the similarity graph H on ``s`` randomly sampled vertices.

        Returns:
            H if it has exactly ``k`` connected components. Otherwise prints
            diagnostics, clears ``global_edge_set`` and returns the string
            "fail".
        """
        S = random.sample(list(self.G.nodes()), self.s)
        H = nx.Graph()
        H.add_nodes_from(S)
        # Every unordered pair of sampled vertices, in the same order as i < j loops.
        for u, v in itertools.combinations(S, 2):
            apx = self.dotOracle.SpectralDotProductOracle(u, v)
            if apx >= self.theta:
                H.add_edge(u, v)

        components = list(nx.connected_components(H))
        if len(components) != self.k:
            print("construct clustering oracle fail!")
            print(len(components))
            print(components)
            global_edge_set.clear()
            return "fail"
        print("construct clustering oracle success!")
        return H

    def searchIndex(self, x):
        """Return the index of the first component of H that ``x`` is similar to.

        ``x`` is similar to a component when no vertex of that component has
        an estimated dot product with ``x`` below ``theta``. Returns the
        string "outlier" when no component qualifies.
        """
        for index, component in enumerate(self.components):
            # any() stops querying at the first dissimilar vertex.
            if not any(self.dotOracle.SpectralDotProductOracle(u, x) < self.theta
                       for u in component):
                return index
        return "outlier"

    def whichCluster(self, x):
        """Return the cluster index of ``x``; outliers get a uniformly random index in [0, k-1]."""
        index = self.searchIndex(x)
        if index == "outlier":
            return random.randint(0, self.k - 1)
        return index




# ---------------------------------------------------------------------------
# Experiment driver (runs at import time).
# For every (n, k, p, q) combination:
#   - generate an SBM graph with syntheticData.SBM and build a
#     ClusteringOracle on it;
#   - query the cluster of every vertex once;
#   - append to Results/queryComplexity.txt how many distinct edges the
#     random walks explored (global_edge_set) and the final accuracy.
# ---------------------------------------------------------------------------
n_list = [5000]
k_list = [3]
p_list = [0.2]
q_list = [0.002]

LOG_PATH = "Results/queryComplexity.txt"
# Every write below (and inside the oracle classes) appends to a file under
# Results/; create the directory up front instead of failing on first write.
os.makedirs("Results", exist_ok=True)

# Same iteration order as nested loops over n, k, p, q (q innermost).
for n, k, p, q in itertools.product(n_list, k_list, p_list, q_list):
    syntheticData.SBM(n, k, p, q)
    N = k * n  # the graph below has nodes 0 .. N-1
    data_path = "./SyntheticData/" + "n=" + str(n) + "_k=" + str(k) + \
                "_p=" + str(p) + "_q=" + str(q) + ".csv"

    R_init = 2000
    R_query = 100
    t = 25
    s_dot = 20

    s = int(10 * k * (np.log(k)))
    theta = 0.0001

    # read data: each CSV row is an edge (u, v); add all N nodes first so that
    # isolated vertices are still present.
    E = pd.read_csv(data_path)
    G = nx.Graph()
    G.add_nodes_from(range(N))
    G.add_edges_from([(u, v) for _, u, v in E.itertuples()])
    m = G.number_of_edges()

    # some basic information
    with open(LOG_PATH, "a") as file:
        file.write("n=" + str(n) + "_k=" + str(k) + "_p=" + str(p) + "_q=" + str(q) + "\n")
        file.write("theta=" + str(theta) + "\n")
        file.write("R_init=" + str(R_init) + "_R_query=" + str(R_query) +
                   "_t=" + str(t) + "_s_dot=" + str(s_dot) + "_s=" + str(s) + "\n\n")
        file.write("Nodes of the graph: " + str(N) + "\n")
        file.write("Edges of the graph: " + str(m) + "\n\n")

    # init: the oracle classes add every edge their random walks traverse to
    # this module-level set.
    global_edge_set = set()
    myOracle = ClusteringOracle(G, k, R_init, R_query, t, s, s_dot, theta)

    with open(LOG_PATH, "a") as file:
        file.write("Fraction of edges (initiate clusteringOracle): " +
                   str(len(global_edge_set)/m) + "\n\n")

    # Query every vertex exactly once, in random order, so the overall
    # accuracy can be computed.
    random_list = random.sample(range(N), N)
    defineRandomList(random_list)

    # ground_truth: vertex u is placed in planted block u // n
    plantedClusters = toClusteringSets([u // n for u in random_list])

    print("now len(global_edge set) should be the same with txt: " + str(len(global_edge_set)))

    # our algorithm: keep one log handle open for all N queries rather than
    # reopening the file for every query.
    inCluster = []
    with open(LOG_PATH, "a") as file:
        for i, u in enumerate(random_list, start=1):
            old_len = len(global_edge_set)
            inCluster.append(myOracle.whichCluster(u))
            new_len = len(global_edge_set)
            print("Query " + str(i) + " times:")
            print(new_len)
            print(new_len / m)

            file.write("Query " + str(i) + " times: (num of edges -- fraction)  " +
                       str(new_len - old_len) + "--" + str((new_len - old_len) / m) + "\n")
            file.write("-----------Fraction of edges: " + str(new_len / m) + "\n")

    clustering = toClusteringSets(inCluster)
    accuracy = getMatching(clustering, plantedClusters)

    print("Accuracy of our algorithm: " + str(accuracy))

    with open(LOG_PATH, "a") as file:
        file.write("accuracy: " + str(accuracy) + "\n")
        file.write("error: " + str(1 - accuracy) + "\n")
