from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial
import os
import warnings

from absl import flags
import tensorflow as tf

from graph_data import *
from utils import *

# Suppress all Python warnings process-wide.
warnings.filterwarnings("ignore")

# Checkpoint and seeding.
flags.DEFINE_string(
    'checkpoint', '',
    'Checkpoint path prefix to restore; the graph is imported from '
    '"<checkpoint>.meta".')
flags.DEFINE_integer('random_seed', 12345, 'Seed passed to random.seed().')
flags.DEFINE_integer('tf_random_seed', 601904901297,
                     'Seed passed to tf.random.set_random_seed().')

# Input example params.
flags.DEFINE_integer('train_batch_size', 32,
                     'Batch size passed to dataset.get_next_train_batch().')
flags.DEFINE_integer('num_train_iters', 25000,
                     'Number of batches to run through the restored model.')
flags.DEFINE_string('dataset', 'graph_rnn_grid',
                    'Dataset to load; must be a key of DATASET_MAP.')
flags.DEFINE_integer('node_embedding_dim', 20, 'Dimension of node embeddings.')
flags.DEFINE_string('node_features', 'gaussian',
                    'Can be laplacian, gaussian, zeros, or positional.')
flags.DEFINE_string(
    'output_file', '',
    'Output path prefix; results are pickled to '
    '"<output_file>_<run_number>.p".')
flags.DEFINE_integer(
    'run_number', 0,
    'Run index; used as the CUDA_VISIBLE_DEVICES value and as the suffix of '
    'the output file name.')
flags.DEFINE_float('gaussian_scale', 0.3,
                   'Scale to use for random Gaussian features.')
# NOTE(review): keep TensorFlow's flags wrapper here rather than absl's
# flags.FLAGS. The rest of this module reads FLAGS at import time without an
# explicit parse step; that appears to rely on the TF1 wrapper parsing
# sys.argv lazily on first attribute access (absl's own FLAGS raises on
# unparsed access) -- confirm against the installed TensorFlow version.
FLAGS = tf.app.flags.FLAGS

# Expose only the CUDA device whose index equals the run number.
os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.run_number)

# Seed the `random` module in scope and TensorFlow from their respective flags.
random.seed(FLAGS.random_seed)
tf.random.set_random_seed(FLAGS.tf_random_seed)

# Print options: no scientific notation, floats shown with three decimals.
np.set_printoptions(formatter={'float': '{: 0.3f}'.format}, suppress=True)

# Every feature builder is pre-bound to the embedding width from the flags;
# the Gaussian variant additionally takes its scale from --gaussian_scale.
_num_components = FLAGS.node_embedding_dim
NODE_FEATURES_MAP = dict(
    laplacian=partial(
        add_laplacian_features, num_components=_num_components),
    gaussian=partial(
        add_gaussian_noise_features,
        num_components=_num_components,
        scale=FLAGS.gaussian_scale),
    zeros=partial(add_zero_features, num_components=_num_components),
    positional=partial(
        add_positional_encoding_features, num_components=_num_components),
)

# Feature function selected by --node_features (KeyError on an unknown name).
add_node_features_fn = NODE_FEATURES_MAP[FLAGS.node_features]

# Graph file backing each dataset name accepted by --dataset.
_GRAPH_RNN_FILES = {
    'graph_rnn_grid': 'graphs/GraphRNN_RNN_grid_4_128_train_0.dat',
    'graph_rnn_protein': 'graphs/GraphRNN_RNN_protein_4_128_train_0.dat',
    'graph_rnn_ego': 'graphs/GraphRNN_RNN_citeseer_4_128_train_0.dat',
    'graph_rnn_community': 'graphs/GraphRNN_RNN_caveman_4_128_train_0.dat',
    'graph_rnn_ego_small':
    'graphs/GraphRNN_RNN_citeseer_small_4_64_train_0.dat',
    'graph_rnn_community_small':
    'graphs/GraphRNN_RNN_caveman_small_4_64_train_0.dat',
}

# All datasets share one loader and the selected node-feature function; they
# differ only in the file they read. Loading is deferred until the partial is
# called.
DATASET_MAP = {
    name: partial(load_grevnet_graph_rnn_dataset, path, add_node_features_fn)
    for name, path in _GRAPH_RNN_FILES.items()
}

# Load only the dataset requested on the command line.
dataset = DATASET_MAP[FLAGS.dataset]()

# Fail fast with a clear message: with the default empty --checkpoint the
# code below would otherwise try to import a file literally named ".meta".
if not FLAGS.checkpoint:
    raise ValueError(
        "--checkpoint must be set to the checkpoint path prefix to restore.")

# Obtain the session from the project helper before importing the graph.
sess = reset_sess()
# Rebuild the saved graph from "<checkpoint>.meta", then restore the
# checkpoint into the session.
saver = tf.train.import_meta_graph("{}.meta".format(FLAGS.checkpoint))
saver.restore(sess, FLAGS.checkpoint)

# Tensors fetched on every step: the first entry of each named graph
# collection (IndexError if the checkpointed graph lacks the collection).
values_map = {
    'gnn_output': tf.get_collection('gnn_output')[0],
    'num_incorrect': tf.get_collection('num_incorrect')[0],
}

# The two lists are shared with `to_pickle`: appending to them fills the dict
# that is written to disk after the loop.
node_features = []
n_node = []
to_pickle = {'node_features': node_features, 'n_node': n_node}

for step in range(FLAGS.num_train_iters):
    graphs_tuple = dataset.get_next_train_batch(FLAGS.train_batch_size)

    # Feeds are addressed by tensor name in the restored graph; the
    # is_training input is always fed False.
    feeds = {
        "true_graph_phs/nodes:0": graphs_tuple.nodes,
        "true_graph_phs/edges:0": graphs_tuple.edges,
        "true_graph_phs/receivers:0": graphs_tuple.receivers,
        "true_graph_phs/senders:0": graphs_tuple.senders,
        "true_graph_phs/globals:0": graphs_tuple.globals,
        "true_graph_phs/n_node:0": graphs_tuple.n_node,
        "true_graph_phs/n_edge:0": graphs_tuple.n_edge,
        "is_training:0": False,
    }

    fetched = sess.run(values_map, feed_dict=feeds)
    # Keep the model output together with the batch's n_node value.
    node_features.append(fetched["gnn_output"])
    n_node.append(graphs_tuple.n_node)

    # Progress report on every 100th step only.
    if step % 100 != 0:
        continue
    print("iteration num {}".format(step))
    print("num_incorrect {}".format(fetched['num_incorrect']))

# Everything collected is written once, after the final step.
output_path = "{}_{}.p".format(FLAGS.output_file, FLAGS.run_number)
with open(output_path, 'wb') as f:
    pickle.dump(to_pickle, f)
