import time
import misError
import syntheticData
import dotProductOracle
import networkx as nx
import seaborn as sns
import pandas as pd
from matplotlib import pyplot as plt
import random
import pandas as pd
import networkx as nx


# G_0: base graph generated by the SBM (stochastic block model)
# G: the graph obtained by deleting some edges from G_0
# G is the actual input graph
# for G, a new appropriate theta has to be found
# to find theta, use an approach similar to the one in theta.py



# Parameter grid: every combination of (n, k, p, q) is run.
#   n: vertices per planted cluster (vertex v belongs to cluster v // n)
#   k: number of planted clusters (the graph has N = k * n vertices)
#   p, q: passed to syntheticData.SBM; presumably the intra- and
#         inter-cluster edge probabilities -- TODO confirm in syntheticData
n_list = [1000]
k_list = [3]
q_list = [0.002]
p_list = [0.05]

# delNum[i]: number of same-cluster edges deleted at one randomly picked
# vertex of every cluster; theta_list[i] is the theta used with delNum[i].
delNum = [0, 25, 32, 40, 45, 50, 55, 60, 65, 72]
theta_list = [
    0.0005, 0.0005, 0.0004, 0.0004, 0.00035,
    0.00035, 0.0003, 0.0003, 0.0003, 0.0003,
]

# Parameters forwarded unchanged to misError.ourAlgorithm; their meaning is
# defined there.
R_init = 2000
R_query = 250
t = 25
s = 21
s_dot = 20

# Number of independent deletion trials per delNum; accuracy is averaged
# over these trials.
repeat = 5

# The experiment pairs delNum[i] with theta_list[i] and divides by `repeat`,
# so reject an inconsistent configuration up front instead of failing (or
# silently ignoring thetas) in the middle of a long run.
if len(delNum) != len(theta_list):
    raise ValueError(
        "delNum and theta_list must have the same length "
        f"(got {len(delNum)} and {len(theta_list)})"
    )
if repeat < 1:
    raise ValueError(f"repeat must be at least 1 (got {repeat})")

# Robustness experiment: for each parameter setting, pick one random vertex
# per planted cluster, delete some of its same-cluster edges, run
# misError.ourAlgorithm on the damaged graph and log the clustering accuracy
# to RESULTS_PATH.
import os

RESULTS_PATH = "./Results/robust.txt"


def _split_neighbors(graph, vertex, n):
    """Partition the neighbors of ``vertex`` in ``graph`` by planted cluster.

    Vertices are numbered cluster by cluster, so vertex ``v`` belongs to
    cluster ``v // n`` (the same convention used for the ground truth below).

    Returns:
        (in_nbrs, out_nbrs): neighbors in the same cluster as ``vertex`` and
        neighbors in a different cluster, each in ``graph.neighbors`` order.
    """
    cluster = vertex // n
    in_nbrs, out_nbrs = [], []
    for u in graph.neighbors(vertex):
        if u // n == cluster:
            in_nbrs.append(u)
        else:
            out_nbrs.append(u)
    return in_nbrs, out_nbrs


def _load_graph(data_path, N):
    """Build the base graph G_0 on vertices 0..N-1 from an edge-list CSV.

    Each CSV row is unpacked as exactly two endpoint columns.
    """
    edges = pd.read_csv(data_path)
    graph = nx.Graph()
    graph.add_nodes_from(range(N))  # keep vertices that have no edges
    graph.add_edges_from((u, v) for _, u, v in edges.itertuples())
    return graph


# Without this the first append to RESULTS_PATH raises FileNotFoundError.
os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)

for n in n_list:
    for k in k_list:
        for p in p_list:
            for q in q_list:
                N = k * n
                syntheticData.SBM(n, k, p, q)
                data_path = f"./SyntheticData/n={n}_k={k}_p={p}_q={q}.csv"

                # G_0 is loaded once per parameter setting; every trial works
                # on its own copy, so deletions never leak between trials.
                base_graph = _load_graph(data_path, N)
                degrees = dict(base_graph.degree())
                max_degree = max(degrees.values())
                min_degree = min(degrees.values())

                for del_index, del_num in enumerate(delNum):
                    theta = theta_list[del_index]

                    total_acc = 0
                    count = 0  # trials in which some picked vertex lost all its in-cluster edges

                    for _ in range(repeat):
                        G = base_graph.copy()

                        with open(RESULTS_PATH, "a") as out:
                            out.write(f"n={n}_k={k}_p={p}_q={q}\n")
                            out.write(f"theta={theta}\n")
                            out.write(
                                f"R_init={R_init}_R_query={R_query}"
                                f"_t={t}_s_dot={s_dot}_s={s}\n"
                            )
                            out.write(f"maximum degree of original graph: {max_degree}\n")
                            out.write(f"minimum degree of original graph: {min_degree}\n\n")

                        del_vtxs = []
                        flag = 0  # picked vertices whose in-cluster edges were all deleted
                        for i in range(k):
                            # pick a uniformly random vertex of cluster i
                            del_vtx = random.randint(i * n, (i + 1) * n - 1)
                            del_vtxs.append(del_vtx)

                            # Deletions in earlier clusters only touch edges
                            # inside those clusters, so del_vtx's neighborhood
                            # is still identical to the one in G_0.
                            in_nbrs, out_nbrs = _split_neighbors(base_graph, del_vtx, n)
                            in_deg = len(in_nbrs)

                            with open(RESULTS_PATH, "a") as out:
                                out.write(f"vertex {del_vtx} of cluster {i} is picked\n")
                                out.write(
                                    "(original) in-degree, out-degree, degree = "
                                    f"{in_deg},{len(out_nbrs)},{degrees[del_vtx]}\n"
                                )
                                out.write(f"we want delete {del_num} edges\n")

                                if in_deg <= del_num:
                                    # every in-cluster edge of del_vtx goes away
                                    flag += 1
                                    out.write(f"indegree = {in_deg} <= {del_num}\n")
                                    targets = in_nbrs
                                else:
                                    # del_num distinct in-cluster neighbors, uniformly at random
                                    targets = random.sample(in_nbrs, del_num)

                                for u in targets:
                                    out.write(f"({del_vtx}, {u}) is deleted\n")
                                out.write("\n")

                            G.remove_edges_from([(del_vtx, u) for u in targets])

                        if flag > 0:
                            count += 1

                        new_degrees = dict(G.degree())
                        with open(RESULTS_PATH, "a") as out:
                            out.write(f"maximum degree of new graph: {max(new_degrees.values())}\n")
                            out.write(f"minimum degree of new graph: {min(new_degrees.values())}\n")
                            # only the picked vertices are reported, so only they are split
                            for u in del_vtxs:
                                new_in, new_out = _split_neighbors(G, u, n)
                                out.write(
                                    "(new) in-degree, out-degree, degree = "
                                    f"{len(new_in)},{len(new_out)},{new_degrees[u]}\n"
                                )
                            out.write(f"{flag} isolate vertices\n")
                            out.write("----------------------\n")

                        # when selecting a sampling size of N, we calculate the overall accuracy
                        random_list = random.sample(range(N), N)
                        misError.defineRandomList(random_list)

                        # ground truth: planted cluster of each vertex, in random_list order
                        plantedClusters = misError.toClusteringSets([u // n for u in random_list])

                        # our algorithm
                        inCluster = misError.ourAlgorithm(G, k, R_init, R_query, t, s, s_dot, theta, False)
                        clustering = misError.toClusteringSets(inCluster)

                        accuracy = misError.getMatching(clustering, plantedClusters)

                        with open(RESULTS_PATH, "a") as out:
                            out.write(f"accuracy is: {accuracy}\n")
                            out.write("______________________________________________________________\n\n")

                        total_acc += accuracy

                    avg_acc = total_acc / repeat
                    with open(RESULTS_PATH, "a") as out:
                        out.write(f"{count} out of {repeat} occurs isolate vertex\n")
                        out.write(f"Average accuracy of delNum={del_num} is: {avg_acc}\n")
                        out.write(f"Average error of delNum={del_num} is: {1 - avg_acc}\n")
                        out.write("------------------------------------------------------\n\n\n\n\n")