import numpy as np
import scipy.sparse as sp
import numba
from ppr_solver import *
import sys
import os

# Every graph this experiment script knows about; only `graph_names` is run.
all_graph = ['Cora', 'Citeseer', 'ogbn-arxiv', 'as-skitter', 'ogbn-proteins', 'com-orkut',
               'cit-patent', 'ogbl-ppa', 'ogbn-products', 'wiki-talk', 'com-youtube', 'ogbn-mag',
               'soc-lj1', 'reddit', 'pubmed', 'wiki-en21', 'com-friendster', 'ogbn-papers100M']

# Graphs processed in this invocation.
graph_names = ['ogbn-papers100M']
path = './dataset/'   # directory holding the <graph>-exp.npz inputs
test_num = 50         # NOTE(review): not referenced elsewhere in this script
alpha = 0.1           # PPR alpha passed unchanged to the solvers
fwd_omega = 1.0       # relaxation factor for the forward-push baseline (no over-relaxation)

# SOR relaxation factor.  With the transformed value alpha' = alpha / (2 - alpha)
# and mu = (1 - alpha') / (1 + alpha'), the expression
#     omega = 1 + (mu / (1 + sqrt(1 - mu^2)))^2  ==  2 / (1 + sqrt(1 - mu^2))
# is the classical optimal-SOR form (mu presumably plays the role of the
# iteration's spectral-radius bound -- see ppr_solver).
# alpha' lives in its own name: previously `alpha` itself was overwritten and
# converted back with 2*alpha'/(1+alpha'), which is an identity only in exact
# arithmetic and left `alpha` temporarily wrong for any code in between.
alpha_lazy = alpha / (2 - alpha)
mu = (1. - alpha_lazy) / (1. + alpha_lazy)
omega = 1. + (mu / (1. + np.sqrt(1. - mu ** 2.))) ** 2.
# Load previously saved experiment results and unwrap the first (presumably
# only) stored object; it is indexed below as exp_result[graph_name][0], so it
# is presumably a dict keyed by graph name.
# NOTE(review): allow_pickle=True is only safe for trusted, locally produced files.
exp_result = np.load(
    './results/ppr_exp_result_new.npy', allow_pickle=True
).reshape(1, -1)[0][0]

for graph_name in graph_names:
    # ogbn-papers100M uses a smaller update batch than every other graph
    # (presumably because of its size -- TODO confirm).
    batchsize = 5 if graph_name == 'ogbn-papers100M' else 50

    # Load the CSR graph and the per-snapshot update lists.  Us[i] / Vs[i] are
    # passed to update_ppr_parallel as `us, vs`; presumably the endpoints of
    # the edges inserted at snapshot i -- verify against ppr_solver.
    dataset = np.load(path + graph_name + '-exp.npz', allow_pickle=True)
    indptr = dataset['indptr']
    indices = dataset['indices']
    degree = dataset['degree']
    Us = dataset['Us']
    Vs = dataset['Vs']
    n = len(indptr) - 1          # number of nodes (CSR row count)
    eps = 1e-2 / n               # solver tolerance, scaled by graph size
    snapshots = len(Us)

    # Slot 0 holds the initial solve, slot k the state after k snapshots.
    # "_dir" tracks re-run init_a_graph (the same call as the initial solve)
    # after every snapshot; the other tracks use update_ppr_parallel.
    # The *_opers arrays are int64, so the (float) mean op count is truncated
    # when stored.
    sor_opers = np.zeros(snapshots + 1, dtype=np.int64)
    sor_algo_times = np.zeros(snapshots + 1)
    sor_opers_dir = np.zeros(snapshots + 1, dtype=np.int64)
    sor_algo_times_dir = np.zeros(snapshots + 1)

    fwd_opers = np.zeros(snapshots + 1, dtype=np.int64)
    fwd_algo_times = np.zeros(snapshots + 1)
    fwd_opers_dir = np.zeros(snapshots + 1, dtype=np.int64)
    fwd_algo_times_dir = np.zeros(snapshots + 1)

    np.random.seed(17)
    s_nodes = exp_result[graph_name][0]
    if graph_name == 'ogbn-papers100M':
        # Subsample the query nodes.  replace=False keeps them distinct:
        # np.random.choice samples WITH replacement by default, which could
        # repeat a node and double-weight it in every mean below.  min()
        # avoids a ValueError if fewer than 5 nodes are stored.
        s_nodes = np.random.choice(s_nodes, size=min(5, len(s_nodes)), replace=False)

    # Initial SOR solve (tuned omega); the "_dir" track starts from the same cost.
    sor_oper, sor_algo_time, xs_sor, rs_sor = init_a_graph(
        n, indptr, indices, degree, alpha, eps, omega, graph_name, s_nodes)
    sor_oper_dir = sor_oper.copy()
    sor_algo_time_dir = sor_algo_time.copy()
    sor_opers[0] = sor_oper.mean()
    sor_algo_times[0] = sor_algo_time.mean()
    sor_opers_dir[0] = sor_oper.mean()
    sor_algo_times_dir[0] = sor_algo_time.mean()

    # Initial forward-push solve: the same call with omega = fwd_omega.
    fwd_oper, fwd_algo_time, xs_fwd, rs_fwd = init_a_graph(
        n, indptr, indices, degree, alpha, eps, omega=fwd_omega,
        graph_name=graph_name, s_nodes=s_nodes)
    fwd_oper_dir = fwd_oper.copy()
    fwd_algo_time_dir = fwd_algo_time.copy()
    fwd_opers[0] = fwd_oper.mean()
    fwd_algo_times[0] = fwd_algo_time.mean()
    fwd_opers_dir[0] = fwd_oper.mean()
    fwd_algo_times_dir[0] = fwd_algo_time.mean()

    # The incremental SOR track gets its own copy of the graph arrays while the
    # forward track (and the from-scratch "_dir" solves) use the originals.
    # Presumably update_ppr_parallel inserts each snapshot's edges into the
    # arrays it is given in place, so each incremental track needs its own
    # copy -- TODO confirm against ppr_solver.
    indptr_tmp = indptr.copy()
    indices_tmp = indices.copy()
    degree_tmp = degree.copy()
    print(0, ':', sor_opers[0], sor_algo_times[0], fwd_opers[0], fwd_algo_times[0],
          sor_opers_dir[0], sor_algo_times_dir[0], fwd_opers_dir[0], fwd_algo_times_dir[0])

    for snapshot in range(snapshots):
        us = Us[snapshot]
        vs = Vs[snapshot]

        # Incremental SOR update; costs are cumulative over snapshots.
        sor_oper_add, sor_algo_time_add, xs_sor, rs_sor = update_ppr_parallel(
            n, indptr_tmp, indices_tmp, degree_tmp, alpha, eps, omega,
            xs_sor, rs_sor, us, vs, batchsize)
        sor_oper += sor_oper_add
        sor_algo_time += sor_algo_time_add
        sor_opers[snapshot + 1] = sor_oper.mean()
        sor_algo_times[snapshot + 1] = sor_algo_time.mean()

        # Incremental forward-push update.
        fwd_oper_add, fwd_algo_time_add, xs_fwd, rs_fwd = update_ppr_parallel(
            n, indptr, indices, degree, alpha, eps, fwd_omega,
            xs_fwd, rs_fwd, us, vs, batchsize)
        fwd_oper += fwd_oper_add
        fwd_algo_time += fwd_algo_time_add
        fwd_opers[snapshot + 1] = fwd_oper.mean()
        fwd_algo_times[snapshot + 1] = fwd_algo_time.mean()

        # From-scratch SOR solve on the current graph (vectors discarded).
        sor_oper_add, sor_algo_time_add, _, _ = init_a_graph(
            n, indptr, indices, degree, alpha, eps, omega, graph_name, s_nodes)
        sor_oper_dir += sor_oper_add
        sor_algo_time_dir += sor_algo_time_add
        sor_opers_dir[snapshot + 1] = sor_oper_dir.mean()
        sor_algo_times_dir[snapshot + 1] = sor_algo_time_dir.mean()

        # From-scratch forward-push solve on the current graph.
        fwd_oper_add, fwd_algo_time_add, _, _ = init_a_graph(
            n, indptr, indices, degree, alpha, eps, fwd_omega, graph_name, s_nodes)
        fwd_oper_dir += fwd_oper_add
        fwd_algo_time_dir += fwd_algo_time_add
        fwd_opers_dir[snapshot + 1] = fwd_oper_dir.mean()
        fwd_algo_times_dir[snapshot + 1] = fwd_algo_time_dir.mean()

        print(snapshot + 1, ':', sor_opers[snapshot + 1], sor_algo_times[snapshot + 1], fwd_opers[snapshot + 1], fwd_algo_times[snapshot + 1],
              sor_opers_dir[snapshot + 1], sor_algo_times_dir[snapshot + 1], fwd_opers_dir[snapshot + 1], fwd_algo_times_dir[snapshot + 1])

    # np.savez does not create missing directories; ensure the per-graph
    # results folder exists so a long run is not lost at the final write.
    out_dir = './results/' + graph_name
    os.makedirs(out_dir, exist_ok=True)
    np.savez(out_dir + '/dygraph_result_new.npz', sor_opers=sor_opers,
             sor_algo_times=sor_algo_times, fwd_opers=fwd_opers, fwd_algo_times=fwd_algo_times,
             sor_opers_dir=sor_opers_dir, sor_algo_times_dir=sor_algo_times_dir, fwd_opers_dir=fwd_opers_dir,
             fwd_algo_times_dir=fwd_algo_times_dir)








