import os

import numba
import numpy as np
import scipy.sparse as scsp

from katz_solver import *

# Driver script: for every graph in `graph_names`, run `solve_a_graph_omega`
# over a sweep of `omega` values and checkpoint the accumulated results to
# ./results/<graph_name>/katz_exp_omega_result_new.npy after each solve.

# Bookkeeping list of dataset names; it is not read anywhere in this script.
finished = ['ogbl-ppa', 'ogbn-products', 'wiki-talk', 'com-youtube', 'ogbn-mag', 'soc-lj1',
               'reddit', 'pubmed', 'wiki-en21', 'Cora', 'Citeseer', 'ogbn-arxiv', 'as-skitter', 'ogbn-proteins', 'com-orkut', 'cit-patent',
               'ogbn-papers100M']
# Graphs processed by this run.
graph_names = ['wiki-talk']
path = './dataset/'
test_num = 50  # not read in this script

# SECURITY NOTE: allow_pickle=True unpickles arbitrary Python objects; only
# load result files that were produced locally / come from a trusted source.
# reshape(1, -1)[0][0] unwraps the first (single) stored object, which is then
# indexed by graph name below.
lambda_result = np.load('./results/katz_exp_lambda.npy', allow_pickle=True)
lambda_result = lambda_result.reshape(1, -1)[0][0]

# NOTE(review): the per-graph entries are read from the *ppr* experiment
# results file — presumably so both experiments share inputs; confirm.
exp_result = np.load('./results/ppr_exp_result_new.npy', allow_pickle=True)
exp_result = exp_result.reshape(1, -1)[0][0]

for graph_name in graph_names:
    # omega -> result returned by solve_a_graph_omega for that omega.
    # NOTE(review): this starts empty for every graph, so the
    # `omega in all_result` branch below can never fire as written. It looks
    # like leftover resume logic; previously saved results are NOT reloaded.
    all_result = {}
    graph_path = path + graph_name + '/'
    adj_matrix = scsp.load_npz(graph_path + graph_name + '_csr-mat.npz')
    indices = adj_matrix.indices
    indptr = adj_matrix.indptr
    # Assuming the stored matrix is CSR (as the file name suggests): indptr has
    # one entry per row plus one, and indices has one entry per stored element.
    n = len(indptr) - 1
    m = len(indices)
    lambda_1 = lambda_result[graph_name]
    alpha = 1. / (lambda_1 + 1)
    # NOTE(review): this has the form of the classical optimal SOR relaxation
    # factor 2 / (1 + sqrt(1 - rho^2)) with rho = alpha * lambda_1 — confirm
    # that this is the intended reference value for the sweep.
    print(2. / (1 + np.sqrt(1 - (alpha * lambda_1) ** 2)))
    np.random.seed(17)
    # First element of this graph's stored experiment entry; passed to the
    # solver as `s_nodes` (presumably the source nodes — verify against the
    # script that produced the file).
    s_nodes = exp_result[graph_name][0]
    # Row sums of the adjacency matrix as a flat 1-D array.
    degree = np.array(adj_matrix.sum(1)).flatten()
    eps = 1e-4 / m

    # Make sure the checkpoint directory exists *before* the sweep starts, so
    # a missing directory fails fast instead of raising FileNotFoundError in
    # np.save only after the first (expensive) solve has completed.
    result_dir = './results/' + graph_name
    os.makedirs(result_dir, exist_ok=True)
    result_file = result_dir + '/katz_exp_omega_result_new.npy'

    # NOTE: np.arange with a float step produces values carrying round-off
    # (e.g. 0.8500000000000001); these exact floats are the dict keys that
    # end up in the saved file.
    for omega in np.arange(0.8,1.96,0.05):
        if omega in all_result:
            print(graph_name, omega, all_result[omega][1].mean(),  all_result[omega][5].mean())
            continue
        print(graph_name, omega, end=' ')
        graph_result = solve_a_graph_omega(n, indptr, indices, degree, alpha, eps, omega, graph_name, s_nodes)
        print(graph_result[1].mean(), graph_result[5].mean())
        all_result[omega] = graph_result
        # Checkpoint after every omega so partial sweeps are kept on disk.
        np.save(result_file, all_result)

