import numpy as np
import torch
import scipy.sparse as sp
import numba
from hk_solver import *
import warnings

warnings.filterwarnings("ignore")

# Heat-kernel (HK) experiment driver: for every graph in `graph_names`, load its
# CSR adjacency matrix, run `solve_a_graph_pk` on the source nodes taken from the
# PPR experiment results, and checkpoint all results to `hk_result_path`.
# Graphs already present in the results file are only reported, not recomputed.

# Graphs that get the looser error tolerance below.
large_graph = ['com-friendster', 'ogbn-papers100M']
graph_names = ['Cora', 'Citeseer', 'ogbn-arxiv', 'as-skitter', 'ogbn-proteins', 'com-orkut', 'cit-patent', 'ogbl-ppa',
               'ogbn-products', 'wiki-talk', 'com-youtube', 'ogbn-mag', 'soc-lj1', 'reddit', 'pubmed', 'wiki-en21',
               'com-friendster', 'ogbn-papers100M']
path = './dataset/'
hk_result_path = './results/hk_exp_pk_result_new.npy'
ppr_result_path = './results/ppr_exp_pk_result_new.npy'

# Solver parameters shared by every graph (loop-invariant, so set once).
tau = 10.0
omega = 1.2


def _load_result_dict(file_path):
    """Load a dict previously stored with ``np.save``.

    ``np.save`` wraps a dict in a 0-d object array; ``.item()`` unwraps it.
    """
    return np.load(file_path, allow_pickle=True).item()


def _summary(entry):
    """Return the fields printed for one stored result entry.

    ``entry`` is the ``(n, mean_degree, result)`` tuple stored per graph; the
    arrays at ``result[1]``, ``result[2]`` and ``result[3]`` are averaged with
    ``.mean()``. What those arrays measure is defined by ``solve_a_graph_pk``
    (not visible here — presumably per-query timing/error metrics; TODO confirm).
    """
    result = entry[2]
    return entry[0], entry[1], result[1].mean(), result[2].mean(), result[3].mean()


try:
    all_result = _load_result_dict(hk_result_path)
except FileNotFoundError:
    # First run: no results have been checkpointed yet.
    all_result = {}
ppr_result = _load_result_dict(ppr_result_path)

for graph_name in graph_names:
    if graph_name in all_result:
        # Computed in an earlier run: just report it.
        print(graph_name, *_summary(all_result[graph_name]))
        continue
    graph_path = path + graph_name + '/'
    adj_matrix = sp.load_npz(graph_path + graph_name + '_csr-mat.npz')
    indices = adj_matrix.indices
    indptr = adj_matrix.indptr
    n = len(indptr) - 1
    m = len(indices)
    degree = np.array(adj_matrix.sum(1)).flatten()
    # Error tolerance scaled by 1 / (m + n); the large graphs use a looser
    # base tolerance (presumably to keep their runtime manageable).
    eps = (1e-4 if graph_name in large_graph else 1e-10) / (m + n)
    # Reseed the global NumPy RNG before each graph so every run starts from
    # the same state (whether the solver draws from it is not visible here).
    np.random.seed(17)
    # Source nodes come from the PPR experiment's results, presumably so that
    # both experiments query the same nodes.
    s_nodes = ppr_result[graph_name][2][0]
    print(graph_name, len(s_nodes), end=' ')
    result = solve_a_graph_pk(n, indptr, indices, degree, tau, eps, omega, graph_name, s_nodes)
    all_result[graph_name] = (n, degree.mean(), result)
    print(*_summary(all_result[graph_name]))
    # Checkpoint after every graph; already-finished graphs are skipped on rerun.
    np.save(hk_result_path, all_result)