from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from math import sqrt

from sonnet.python.modules import base
import graph_nets as gn
import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp

from loss import *

# Shorthand for the TF1 contrib fully-connected layer builder used by GRUBlock.
fc = tf.contrib.layers.fully_connected
# Shorthand for the TensorFlow Probability bijectors namespace.
tfb = tfp.bijectors


# Blocks to update a node's embedding based on its neighbors' embeddings.
class GRUBlock(snt.AbstractModule):
    """GRU-style node update driven by aggregated received-edge features.

    For every node the features of its received edges are aggregated with
    `agg_fn`; three small fully-connected stacks (r_t, z_t, g_t) then combine
    that aggregate with the current node features, and the new node features
    are `(1 - z_t) * g_t + z_t * nodes`.

    Args:
        output_dim: Width of the gate outputs (cast to int). Presumably this
            has to match the node feature width, since the result is blended
            with `graph.nodes` -- TODO confirm.
        num_layers: Number of hidden ReLU layers inside each gate.
        latent_dim: Width of the hidden layers.
        bias_init_stddev: Argument passed to the truncated-normal bias
            initializer.
        agg_fn: Segment reduction handed to `ReceivedEdgesToNodesAggregator`.
        name: Sonnet module name.
    """

    def __init__(self,
                 output_dim,
                 num_layers=2,
                 latent_dim=256,
                 bias_init_stddev=0.01,
                 agg_fn=tf.unsorted_segment_mean,
                 name="GRUBlock"):
        super(GRUBlock, self).__init__(name=name)
        aggregator = gn.blocks.ReceivedEdgesToNodesAggregator(agg_fn)
        self.received_edges_aggregator = aggregator
        self.output_dim = int(output_dim)
        self.num_layers = num_layers
        self.latent_dim = latent_dim
        self.bias_init_stddev = bias_init_stddev

    def _build(self, graph):
        """Return `graph` with its nodes replaced by the gated update."""
        node_state = graph.nodes
        incoming = self.received_edges_aggregator(graph)
        bias_init = tf.initializers.truncated_normal(self.bias_init_stddev)
        custom_bias = {"biases_initializer": bias_init}

        def gate(first_width, hidden_kwargs, final_activation,
                 node_scale=None):
            # The `fc` calls below must stay in this order: each call creates
            # variables, so the order determines their auto-generated names.
            agg_proj = fc(
                incoming,
                first_width,
                activation_fn=None,
                biases_initializer=None)
            node_proj = fc(
                node_state,
                first_width,
                activation_fn=None,
                biases_initializer=bias_init)
            if node_scale is not None:
                node_proj = node_scale * node_proj
            hidden = tf.nn.relu(agg_proj + node_proj)
            for _ in range(self.num_layers):
                hidden = fc(
                    hidden,
                    self.latent_dim,
                    activation_fn=tf.nn.relu,
                    **hidden_kwargs)
            return fc(
                hidden,
                self.output_dim,
                activation_fn=final_activation,
                biases_initializer=bias_init)

        # NOTE(review): unlike z_t and g_t, the hidden layers of r_t do not
        # receive the truncated-normal bias initializer and fall back to the
        # `fc` default. Kept as-is to preserve behavior; confirm intent.
        r_t = gate(self.latent_dim, {}, tf.sigmoid)
        z_t = gate(self.latent_dim, custom_bias, tf.sigmoid)
        # The candidate state scales the node projection by r_t and uses
        # output_dim (not latent_dim) for its first layer so shapes agree.
        g_t = gate(self.output_dim, custom_bias, tf.tanh, node_scale=r_t)

        blended = (1 - z_t) * g_t + z_t * node_state
        return graph.replace(nodes=blended)


class ConcatThenMLPBlock(snt.AbstractModule):
    """Node update: MLP over [node features, aggregated received edges].

    Args:
        aggn_fn: Segment reduction handed to `ReceivedEdgesToNodesAggregator`.
        make_mlp_fn: Zero-argument callable returning the module applied to
            the concatenated features.
        name: Sonnet module name.
    """

    # NOTE(review): the default name below is "AggThenMLPBlock", identical to
    # the sibling class; it looks like a copy-paste. Left unchanged because
    # renaming could presumably alter variable scopes of existing checkpoints.
    def __init__(self, aggn_fn, make_mlp_fn, name="AggThenMLPBlock"):
        super(ConcatThenMLPBlock, self).__init__(name=name)
        aggregator = gn.blocks.ReceivedEdgesToNodesAggregator(aggn_fn)
        self._received_edges_aggregator = aggregator
        self._mlp = make_mlp_fn()

    def _build(self, graph):
        """Return `graph` with nodes replaced by the MLP output."""
        aggregated = self._received_edges_aggregator(graph)
        mlp_input = tf.concat([graph.nodes, aggregated], axis=1)
        return graph.replace(nodes=self._mlp(mlp_input))


class AggThenMLPBlock(snt.AbstractModule):
    """Node update: MLP over `epsilon * nodes + aggregated received edges`.

    Args:
        aggn_fn: Segment reduction handed to `ReceivedEdgesToNodesAggregator`.
        make_mlp_fn: Zero-argument callable returning the module applied to
            the combined features.
        epsilon: Multiplier applied to a node's own features before the
            aggregate is added.
        name: Sonnet module name.
    """

    def __init__(self, aggn_fn, make_mlp_fn, epsilon, name="AggThenMLPBlock"):
        super(AggThenMLPBlock, self).__init__(name=name)
        aggregator = gn.blocks.ReceivedEdgesToNodesAggregator(aggn_fn)
        self._received_edges_aggregator = aggregator
        self._mlp = make_mlp_fn()
        self.epsilon = epsilon

    def _build(self, graph):
        """Return `graph` with nodes replaced by the MLP output."""
        scaled_self = self.epsilon * graph.nodes
        aggregated = self._received_edges_aggregator(graph)
        updated = self._mlp(scaled_self + aggregated)
        return graph.replace(nodes=updated)


# Building blocks for NodeBlockGNN, a GNN that updates node embeddings based on
# the neighbor node embeddings.
class IdentityModule(base.AbstractModule):
    """Sonnet module that passes its input through `tf.identity`."""

    def _build(self, inputs):
        """Return `tf.identity(inputs)`."""
        passthrough = tf.identity(inputs)
        return passthrough


# Default keyword options unpacked into gn.blocks.EdgeBlock by NodeBlockGNN:
# of the four possible inputs, only the sender-node features are enabled.
EDGE_BLOCK_OPT = {
    "use_edges": False,
    "use_receiver_nodes": False,
    "use_sender_nodes": True,
    "use_globals": False,
}


class NodeBlockGNN(snt.AbstractModule):
    """GNN layer: an identity edge block followed by a caller-supplied node block.

    The edge block uses `IdentityModule` as its edge model, so the edge
    features it produces are whatever inputs `edge_block_opt` enables (by
    default only the sender-node features). The node block then consumes the
    resulting graph.

    Args:
        node_block: Callable (typically a Sonnet module) applied to the graph
            returned by the edge block.
        edge_block_opt: Dict of keyword options for `gn.blocks.EdgeBlock`.
            `None` selects the module-level `EDGE_BLOCK_OPT`.
        name: Sonnet module name.
    """

    def __init__(self,
                 node_block,
                 edge_block_opt=EDGE_BLOCK_OPT,
                 name="NodeBlockGNN"):
        super(NodeBlockGNN, self).__init__(name=name)

        if edge_block_opt is None:
            edge_block_opt = EDGE_BLOCK_OPT

        with self._enter_variable_scope():
            # Honor the caller's options; the argument used to be ignored in
            # favor of the global EDGE_BLOCK_OPT.
            self._edge_block = gn.blocks.EdgeBlock(
                edge_model_fn=IdentityModule, **edge_block_opt)
            self._node_block = node_block

    def _build(self, graph):
        """Run the edge block, then the node block, and return the result."""
        return self._node_block(self._edge_block(graph))


def make_mlp_model(latent_dimension,
                   input_dimension,
                   num_layers,
                   activation=tf.nn.relu,
                   l2_regularizer_weight=0.01,
                   bias_init_stddev=0.1):
    """Build a `snt.Sequential` wrapping a single regularized `snt.nets.MLP`.

    Args:
        latent_dimension: Width of each of the first `num_layers - 1` layers.
        input_dimension: Width of the final layer.
        num_layers: Total number of layers.
        activation: Activation used between layers (not after the last one).
        l2_regularizer_weight: Scale of the L2 regularizer on weights and
            biases.
        bias_init_stddev: Stddev of the truncated-normal bias initializer.

    Returns:
        A `snt.Sequential` containing the MLP.
    """
    output_sizes = [latent_dimension] * (num_layers - 1) + [input_dimension]

    weight_init = tf.contrib.layers.xavier_initializer(uniform=True)
    bias_init = tf.initializers.truncated_normal(stddev=bias_init_stddev)
    weight_reg = tf.contrib.layers.l2_regularizer(l2_regularizer_weight)
    bias_reg = tf.contrib.layers.l2_regularizer(l2_regularizer_weight)

    mlp = snt.nets.MLP(
        output_sizes,
        activation=activation,
        initializers={'w': weight_init, 'b': bias_init},
        regularizers={'w': weight_reg, 'b': bias_reg},
        activate_final=False)
    return snt.Sequential([mlp])


class TimestepGNN(snt.AbstractModule):
    """Applies GNN(s) built by `make_gnn_fn` for `num_timesteps` steps.

    Before each step the node features can optionally be batch-normalized
    and/or layer-normalized. After the last step the input node features can
    optionally be added back (a single residual connection around all steps).

    Args:
        make_gnn_fn: Zero-argument callable returning a GNN module.
        num_timesteps: Number of GNN applications.
        weight_sharing: If True one GNN is reused at every step; otherwise a
            separate GNN is built per step.
        use_batch_norm: Apply a per-step `snt.BatchNorm` to the nodes.
        residual: Add the input nodes to the final nodes.
        test_local_stats: Forwarded to `snt.BatchNorm` at call time.
        use_layer_norm: Apply a per-step `snt.LayerNorm` to the nodes.
        name: Sonnet module name.
    """

    def __init__(self,
                 make_gnn_fn,
                 num_timesteps,
                 weight_sharing=False,
                 use_batch_norm=False,
                 residual=True,
                 test_local_stats=False,
                 use_layer_norm=False,
                 name="TimestepGNN"):
        super(TimestepGNN, self).__init__(name=name)
        self._weight_sharing = weight_sharing
        self._num_timesteps = num_timesteps
        self._use_batch_norm = use_batch_norm
        self._use_layer_norm = use_layer_norm
        self._residual = residual
        self._test_local_stats = test_local_stats
        # Sub-modules are built in a fixed order (GNNs, batch norms, layer
        # norms) inside this module's variable scope.
        with self._enter_variable_scope():
            if weight_sharing:
                self._gnn = make_gnn_fn()
            else:
                self._gnn = [make_gnn_fn() for _ in range(num_timesteps)]
            self._bns = ([
                snt.BatchNorm(scale=True) for _ in range(num_timesteps)
            ] if use_batch_norm else [])
            self._lns = ([snt.LayerNorm() for _ in range(num_timesteps)]
                         if use_layer_norm else [])

    def _build(self, input_op, is_training):
        """Run all timesteps on `input_op` and return the resulting graph."""
        current = input_op
        for step in range(self._num_timesteps):
            if self._use_batch_norm:
                batch_normed = self._bns[step](
                    current.nodes,
                    is_training=is_training,
                    test_local_stats=self._test_local_stats)
                current = current.replace(nodes=batch_normed)
            if self._use_layer_norm:
                layer_normed = self._lns[step](current.nodes)
                current = current.replace(nodes=layer_normed)
            gnn = self._gnn if self._weight_sharing else self._gnn[step]
            current = gnn(current)
        if not self._residual:
            return current
        return current.replace(nodes=current.nodes + input_op.nodes)


def avg_then_mlp_gnn(make_mlp_fn, epsilon):
    """Return a NodeBlockGNN whose node block is a mean-aggregating AggThenMLPBlock."""
    node_block = AggThenMLPBlock(
        aggn_fn=tf.unsorted_segment_mean,
        make_mlp_fn=make_mlp_fn,
        epsilon=epsilon)
    return NodeBlockGNN(node_block)


def sum_then_mlp_gnn(make_mlp_fn, epsilon):
    """Return a NodeBlockGNN whose node block is a sum-aggregating AggThenMLPBlock."""
    node_block = AggThenMLPBlock(
        aggn_fn=tf.unsorted_segment_sum,
        make_mlp_fn=make_mlp_fn,
        epsilon=epsilon)
    return NodeBlockGNN(node_block)


def sum_concat_then_mlp_gnn(make_mlp_fn):
    """Return a NodeBlockGNN whose node block is a sum-aggregating ConcatThenMLPBlock."""
    concat_block = ConcatThenMLPBlock(
        aggn_fn=tf.unsorted_segment_sum, make_mlp_fn=make_mlp_fn)
    return NodeBlockGNN(concat_block)


def avg_concat_then_mlp_gnn(make_mlp_fn):
    """Return a NodeBlockGNN whose node block is a mean-aggregating ConcatThenMLPBlock."""
    concat_block = ConcatThenMLPBlock(
        aggn_fn=tf.unsorted_segment_mean, make_mlp_fn=make_mlp_fn)
    return NodeBlockGNN(concat_block)


def make_batch_norm():
    """Return a TFP BatchNormalization bijector with a strictly positive gamma.

    The wrapped Keras layer normalizes over the last axis and constrains
    gamma to `relu(gamma) + 1e-6`. The bijector is always created with
    `training=True`.
    """

    def _positive_gamma(gamma):
        # Keeps the scale strictly above zero.
        return tf.nn.relu(gamma) + 1e-6

    layer = tf.layers.BatchNormalization(
        axis=-1, gamma_constraint=_positive_gamma)
    return tfb.BatchNormalization(batchnorm_layer=layer, training=True)


#class GRevNetMasking(snt.AbstractModule):
#    def __init__(self,
#                 make_gnn_fn,
#                 num_timesteps,
#                 node_embedding_dim,
#                 use_batch_norm=False,
#                 name="GRevNetMasking"):
#        super(GRevNetMasking, self).__init__(name=name)
#        self.s = [make_gnn_fn() for _ in range(2 * num_timesteps)]
#        self.t = [make_gnn_fn() for _ in range(2 * num_timesteps)]
#        self.masks = []
#        dim = int(node_embedding_dim / 2)
#        for i in range(num_timesteps):
#            self.masks.append(
#                tf.constant([0.] * dim + [1.] * dim, dtype=tf.float32))
#            self.masks.append(
#                tf.constant([1.] * dim + [0.] * dim, dtype=tf.float32))
#        self.use_batch_norm = use_batch_norm
#        self.bns = [make_batch_norm() for _ in range(2 * num_timesteps)]
#
#    def f(self, x):
#        log_det_jacobian = 0
#        for i in range(len(self.t)):
#            mask = self.masks[i]
#            x_ = x.replace(nodes=x.nodes * mask)
#            if self.use_batch_norm:
#                bn = self.bns[i]
#                log_det_jacobian += bn.inverse_log_det_jacobian(x_.nodes, 2)
#                x_ = x.replace(nodes=bn.inverse(x_.nodes))
#            s = self.s[i](x_).nodes * (1 - mask)
#            t = self.t[i](x_).nodes * (1 - mask)
#            log_det_jacobian += tf.reduce_sum(s)
#            updated_nodes = x_.nodes + (1 - mask) * (x.nodes * tf.exp(s) + t)
#            x = x.replace(nodes=updated_nodes)
#        return x, log_det_jacobian
#
#    def g(self, z):
#        for i in reversed(range(len(self.t))):
#            mask = self.masks[i]
#            z_ = z.replace(nodes=z.nodes * mask)
#            s = self.s[i](z_).nodes * (1 - mask)
#            t = self.t[i](z_).nodes * (1 - mask)
#            if self.use_batch_norm:
#                bn = self.bns[i]
#                z_ = z_.replace(nodes=bn.forward(z_.nodes) * mask)
#            updated_nodes = (1 - mask) * (z.nodes - t) * tf.exp(-s) + z_.nodes
#            z = z.replace(nodes=updated_nodes)
#        return z
#
#    def log_prob(self, x):
#        z, log_det_jacobian = self.f(x)
#        return tf.reduce_sum(self.prior.log_prob(z)) + log_det_jacobian
#
#    def _build(self, input, inverse=True):
#        func = self.f if inverse else self.g
#        return func(input)


def get_gnns(num_timesteps, make_gnn_fn):
    """Call `make_gnn_fn` `num_timesteps` times and return the results as a list."""
    gnns = [make_gnn_fn() for _step in range(num_timesteps)]
    return gnns

class GRevNet(snt.AbstractModule):
    """Reversible GNN made of alternating affine coupling steps.

    Node features are split into two halves along axis 1. At every timestep
    the first half produces a scale `s` and shift `t` (via GNNs built by
    `make_gnn_fn`) that update the second half as `x1 * exp(s) + t`, and then
    the roles are swapped. `f` applies the steps and accumulates the
    log-determinant; `g` walks the steps in reverse order using
    `(z - t) * exp(-s)`.

    Args:
        make_gnn_fn: Zero-argument callable returning a GNN module used as a
            scale or shift network.
        num_timesteps: Number of coupling steps.
        node_embedding_dim: Currently unused; kept for interface
            compatibility.
        use_batch_norm: If True, a batch-norm bijector is applied to each
            half before it is used as the conditioner.
        weight_sharing: If True, one s/t network pair per half is reused at
            every timestep; otherwise each timestep has its own networks.
        name: Sonnet module name.
        prior: Optional object exposing `log_prob`, required only by
            `log_prob`. May also be assigned later via the `prior` attribute.
    """

    def __init__(self,
                 make_gnn_fn,
                 num_timesteps,
                 node_embedding_dim,
                 use_batch_norm=False,
                 weight_sharing=False,
                 name="GRevNet",
                 prior=None):
        super(GRevNet, self).__init__(name=name)
        self.num_timesteps = num_timesteps
        self.weight_sharing = weight_sharing
        # self.s / self.t are indexed by half (0 or 1); without weight sharing
        # each entry is itself a per-timestep list.
        if weight_sharing:
            self.s = [make_gnn_fn(), make_gnn_fn()]
            self.t = [make_gnn_fn(), make_gnn_fn()]
        else:
            self.s = [
                get_gnns(num_timesteps, make_gnn_fn),
                get_gnns(num_timesteps, make_gnn_fn)
            ]
            self.t = [
                get_gnns(num_timesteps, make_gnn_fn),
                get_gnns(num_timesteps, make_gnn_fn)
            ]
        self.use_batch_norm = use_batch_norm
        self.bns = [[make_batch_norm() for _ in range(num_timesteps)],
                    [make_batch_norm() for _ in range(num_timesteps)]]
        # Defined up front so `log_prob` can report a clear error instead of
        # failing with an AttributeError when no prior has been supplied.
        self.prior = prior

    def _scale_and_shift(self, half, step, conditioner):
        """Return the (s, t) node tensors produced from `conditioner`."""
        if self.weight_sharing:
            s_net, t_net = self.s[half], self.t[half]
        else:
            s_net, t_net = self.s[half][step], self.t[half][step]
        # s is evaluated before t, matching the original call order.
        s = s_net(conditioner).nodes
        t = t_net(conditioner).nodes
        return s, t

    def f(self, x):
        """Apply all coupling steps to graph `x`.

        Returns:
            A tuple `(graph, log_det_jacobian)`.
        """
        log_det_jacobian = 0
        x0, x1 = tf.split(x.nodes, num_or_size_splits=2, axis=1)
        x0 = x.replace(nodes=x0)
        x1 = x.replace(nodes=x1)
        for i in range(self.num_timesteps):
            # First half conditions the update of the second half.
            if self.use_batch_norm:
                bn = self.bns[0][i]
                log_det_jacobian += bn.inverse_log_det_jacobian(x0.nodes, 2)
                x0 = x0.replace(nodes=bn.inverse(x0.nodes))
            s, t = self._scale_and_shift(0, i, x0)
            log_det_jacobian += tf.reduce_sum(s)
            x1 = x1.replace(nodes=x1.nodes * tf.exp(s) + t)

            # Second half conditions the update of the first half.
            if self.use_batch_norm:
                bn = self.bns[1][i]
                log_det_jacobian += bn.inverse_log_det_jacobian(x1.nodes, 2)
                x1 = x1.replace(nodes=bn.inverse(x1.nodes))
            s, t = self._scale_and_shift(1, i, x1)
            log_det_jacobian += tf.reduce_sum(s)
            x0 = x0.replace(nodes=x0.nodes * tf.exp(s) + t)

        x = x.replace(nodes=tf.concat([x0.nodes, x1.nodes], axis=1))
        return x, log_det_jacobian

    def g(self, z):
        """Walk the coupling steps of `f` in reverse order on graph `z`."""
        z0, z1 = tf.split(z.nodes, num_or_size_splits=2, axis=1)
        z0 = z.replace(nodes=z0)
        z1 = z.replace(nodes=z1)
        for i in reversed(range(self.num_timesteps)):
            # s/t must be computed from the normalized half, i.e. before the
            # batch-norm bijector is run forward on it.
            s, t = self._scale_and_shift(1, i, z1)
            if self.use_batch_norm:
                bn = self.bns[1][i]
                z1 = z1.replace(nodes=bn.forward(z1.nodes))
            z0 = z0.replace(nodes=(z0.nodes - t) * tf.exp(-s))

            s, t = self._scale_and_shift(0, i, z0)
            if self.use_batch_norm:
                bn = self.bns[0][i]
                z0 = z0.replace(nodes=bn.forward(z0.nodes))
            z1 = z1.replace(nodes=(z1.nodes - t) * tf.exp(-s))
        return z.replace(nodes=tf.concat([z0.nodes, z1.nodes], axis=1))

    def log_prob(self, x):
        """Return `sum(prior.log_prob(f(x))) + log_det_jacobian`.

        Raises:
            ValueError: If no prior has been provided.
        """
        if self.prior is None:
            raise ValueError(
                "GRevNet.log_prob requires a prior; pass `prior=` to the "
                "constructor or assign the `prior` attribute.")
        z, log_det_jacobian = self.f(x)
        # NOTE(review): `z` is the whole graph returned by `f`, not
        # `z.nodes`; confirm the prior's `log_prob` accepts a graph.
        return tf.reduce_sum(self.prior.log_prob(z)) + log_det_jacobian

    def _build(self, input, inverse=True):
        """Dispatch to `f` when `inverse` is True, otherwise to `g`."""
        func = self.f if inverse else self.g
        return func(input)



class SelfAttention(snt.AbstractModule):
    """Multi-head graph self-attention followed by a projection and an MLP.

    Queries and keys get a separate projection per head; the value projection
    is computed once and repeated for every head. Attention itself is
    delegated to `gn.modules.SelfAttention`. The per-head outputs are
    flattened, linearly projected to `multi_proj_dim`, optionally
    concatenated with the input nodes, and passed through the MLP.

    Args:
        kv_dim: Per-head width of the query and key projections.
        output_dim: Width of the value projection.
        make_mlp_fn: Zero-argument callable returning the final MLP module.
        batch_size: Stored on the instance; not used by `_build`.
        num_heads: Number of attention heads.
        multi_proj_dim: Width of the projection applied to the flattened
            multi-head output.
        concat: If True, concatenate the input nodes with the attended
            features before the MLP.
        residual: If True, add the input nodes to the MLP output.
        layer_norm: If True, apply `snt.LayerNorm` to the result.
        name: Sonnet module name.
    """

    def __init__(self,
                 kv_dim,
                 output_dim,
                 make_mlp_fn,
                 batch_size,
                 num_heads=8,
                 multi_proj_dim=20,
                 concat=True,
                 residual=False,
                 layer_norm=False,
                 name="faster_self_attention"):
        super(SelfAttention, self).__init__(name=name)
        self.kv_dim = kv_dim
        self.output_dim = output_dim
        self.mlp = make_mlp_fn()
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.multi_proj_dim = multi_proj_dim
        self.concat = concat
        self.residual = residual
        self.layer_norm = layer_norm

    def _build(self, graph):
        """Return `graph` with nodes replaced by the attended features.

        In the shape comments below, N is the leading dimension of
        `graph.nodes`.
        """
        initializers = {
            'w': tf.contrib.layers.xavier_initializer(uniform=True),
        }

        # Sub-modules are constructed and connected in a fixed order; keep it
        # stable so auto-generated module names do not shift.
        # project_q / project_k: [N, num_heads * kv_dim].
        project_q_mod = snt.Linear(
            self.num_heads * self.kv_dim,
            use_bias=False,
            initializers=initializers)
        project_q = project_q_mod(graph.nodes)
        project_k_mod = snt.Linear(
            self.num_heads * self.kv_dim,
            use_bias=False,
            initializers=initializers)
        project_k = project_k_mod(graph.nodes)

        # Split out the head axis: [N, num_heads, kv_dim]. The reshape already
        # yields the wanted layout, so no transpose is needed.
        project_q = tf.reshape(project_q, [-1, self.num_heads, self.kv_dim])
        project_k = tf.reshape(project_k, [-1, self.num_heads, self.kv_dim])

        # The same value projection is shared by all heads:
        # [N, output_dim] -> [N, num_heads, output_dim].
        project_v_mod = snt.Linear(
            self.output_dim, use_bias=False, initializers=initializers)
        project_v = project_v_mod(graph.nodes)
        project_v = tf.keras.backend.repeat(project_v, self.num_heads)

        attn_module = gn.modules.SelfAttention()
        attn_graph = attn_module(project_v, project_q, project_k, graph)

        # Flatten the heads: [N, num_heads, output_dim]
        # -> [N, num_heads * output_dim].
        new_nodes = tf.reshape(attn_graph.nodes,
                               [-1, self.num_heads * self.output_dim])

        # Mix the heads down to [N, multi_proj_dim].
        new_node_proj = snt.Linear(self.multi_proj_dim, use_bias=False)
        new_nodes = new_node_proj(new_nodes)

        if self.concat:
            new_nodes = tf.concat([graph.nodes, new_nodes], axis=1)
        new_nodes = self.mlp(new_nodes)

        if self.residual:
            new_nodes += graph.nodes

        if self.layer_norm:
            ln_mod = snt.LayerNorm()
            new_nodes = ln_mod(new_nodes)
        return graph.replace(nodes=new_nodes)


def self_attn_gnn(kv_dim,
                  output_dim,
                  make_mlp_fn,
                  batch_size,
                  num_heads,
                  multi_proj_dim,
                  concat=True,
                  residual=False,
                  layer_norm=False):
    """Construct a `SelfAttention` module from the given settings."""
    return SelfAttention(
        kv_dim=kv_dim,
        output_dim=output_dim,
        make_mlp_fn=make_mlp_fn,
        batch_size=batch_size,
        num_heads=num_heads,
        multi_proj_dim=multi_proj_dim,
        concat=concat,
        residual=residual,
        layer_norm=layer_norm)


class MySelfAttention(snt.AbstractModule):
    """Single-head dense self-attention over all nodes, followed by an MLP.

    Scaled dot-product logits are computed between every pair of nodes,
    masked with `loss_mask(graph)`, and soft-maxed. The attended values are
    concatenated with the input nodes and fed to the MLP.

    Args:
        kq_dim: Width of the query and key projections.
        output_dim: Width of the value projection.
        make_mlp_fn: Zero-argument callable returning the final MLP module.
        name: Sonnet module name.
    """

    def __init__(self,
                 kq_dim,
                 output_dim,
                 make_mlp_fn,
                 name="my_self_attention"):
        super(MySelfAttention, self).__init__(name=name)
        self.kq_dim = kq_dim
        self.output_dim = output_dim
        self.mlp = make_mlp_fn()

    def _build(self, graph):
        """Return `graph` with nodes replaced by the MLP output."""
        project_q_mod = snt.Linear(self.kq_dim, use_bias=False)
        project_k_mod = snt.Linear(self.kq_dim, use_bias=False)
        project_v_mod = snt.Linear(self.output_dim, use_bias=False)

        project_q = project_q_mod(graph.nodes)
        project_k = project_k_mod(graph.nodes)
        project_v = project_v_mod(graph.nodes)

        logits = tf.matmul(project_q, tf.transpose(project_k)) / sqrt(self.kq_dim)

        # Presumably `loss_mask` is positive where attention is allowed and
        # non-positive elsewhere -- TODO confirm against the `loss` module.
        mask = loss_mask(graph)
        keep = tf.cast(mask > 0, logits.dtype)
        # Masked positions receive a large negative logit (additive masking).
        # Multiplying by a negative constant, as was done before, turns
        # negative logits into large positive ones and so boosts the masked
        # entries instead of suppressing them.
        masked_logit = -1e9
        logits = logits * mask * keep + (1.0 - keep) * masked_logit
        attn_weights = tf.nn.softmax(logits, axis=-1)

        attended_nodes = tf.matmul(attn_weights, project_v)
        concat_nodes = tf.concat([graph.nodes, attended_nodes], axis=-1)
        new_nodes = self.mlp(concat_nodes)
        return graph.replace(nodes=new_nodes)

def my_self_attn_gnn(kq_dim,
                     output_dim,
                     make_mlp_fn):
    """Construct a `MySelfAttention` module from the given settings."""
    return MySelfAttention(
        kq_dim=kq_dim, output_dim=output_dim, make_mlp_fn=make_mlp_fn)

#class FasterSelfAttentionGN(snt.AbstractModule):
#    def __init__(self,
#                 kv_dim,
#                 output_dim,
#                 make_mlp_fn,
#                 batch_size,
#                 residual=False,
#                 layer_norm=False,
#                 name="faster_self_attention"):
#        super(FasterSelfAttentionGN, self).__init__(name=name)
#        self.kv_dim = kv_dim
#        self.output_dim = output_dim
#        self.mlp = make_mlp_fn()
#        self.batch_size = batch_size
#        self.residual = residual
#        self.layer_norm = layer_norm
#
#    def _build(self, graph):
#        initializers = {
#            'w': tf.contrib.layers.xavier_initializer(uniform=True),
#        }
#        project_q_mod = snt.Linear(
#            self.kv_dim, use_bias=False, initializers=initializers)
#        project_k_mod = snt.Linear(
#            self.kv_dim, use_bias=False, initializers=initializers)
#        project_v_mod = snt.Linear(
#            self.output_dim, use_bias=False, initializers=initializers)
#
#        project_q = tf.expand_dims(project_q_mod(graph.nodes), axis=1)
#        project_k = tf.expand_dims(project_k_mod(graph.nodes), axis=1)
#        project_v = tf.expand_dims(project_v_mod(graph.nodes), axis=1)
#
#        attn_module = gn.modules.SelfAttention()
#        attn_graph = attn_module(project_v, project_q, project_k, graph)
#        new_nodes = attn_graph.nodes[:, 0, :]
#        new_nodes = tf.concat([graph.nodes, new_nodes], axis=1)
#        new_nodes = self.mlp(new_nodes)
#        return graph.replace(nodes=new_nodes)
#
#
#def self_attn_gnn(kv_dim,
#                  output_dim,
#                  make_mlp_fn,
#                  batch_size,
#                  residual=False,
#                  layer_norm=False):
#    return FasterSelfAttentionGN(kv_dim, output_dim, make_mlp_fn, batch_size,
#                                 residual, layer_norm)
#
#def self_attn_gnn(kv_dim,
#                  output_dim,
#                  make_mlp_fn,
#                  batch_size,
#                  residual=False,
#                  layer_norm=False):
#    return FasterSelfAttention(kv_dim, output_dim, make_mlp_fn, batch_size,
#                               residual, layer_norm)
#class FasterSelfAttention(snt.AbstractModule):
#    def __init__(self,
#                 kv_dim,
#                 output_dim,
#                 make_mlp_fn,
#                 batch_size,
#                 residual=False,
#                 layer_norm=False,
#                 name="faster_self_attention"):
#        super(FasterSelfAttention, self).__init__(name=name)
#        self.kv_dim = kv_dim
#        self.output_dim = output_dim
#        self.mlp = make_mlp_fn()
#        self.batch_size = batch_size
#        self.residual = residual
#        self.layer_norm = layer_norm
#
#    def _build(self, graph):
#        dim = tf.shape(graph.nodes)[-1]
#
#        project_q_mod = snt.Linear(self.kv_dim, use_bias=False)
#        project_k_mod = snt.Linear(self.kv_dim, use_bias=False)
#        project_v_mod = snt.Linear(self.output_dim, use_bias=False)
#
#        project_q = project_q_mod(graph.nodes)
#        project_k = project_k_mod(graph.nodes)
#        project_v = project_v_mod(graph.nodes)
#
#        logits = tf.matmul(project_q, tf.transpose(project_k)) / sqrt(
#            self.kv_dim)
#        attn_weights = tf.nn.softmax(logits, axis=-1)
#        mask = loss_mask(graph.n_node, self.batch_size)
#        attn_weights *= mask
#        attended_nodes = tf.matmul(attn_weights, project_v)
#
#        concat_nodes = tf.concat([graph.nodes, attended_nodes], axis=-1)
#        nodes = self.mlp(concat_nodes)
#        if self.residual:
#            nodes += graph.nodes
#
#        if self.layer_norm:
#            ln_mod = snt.LayerNorm()
#            nodes = ln_mod(nodes)
#
#        return graph.replace(nodes=nodes)
#
#
#class FasterSelfAttentionGNMultiHead(snt.AbstractModule):
#    def __init__(self,
#                 kq_dim,
#                 output_dim,
#                 make_mlp_fn,
#                 batch_size,
#                 residual=False,
#                 layer_norm=False,
#                 num_heads=1,
#                 name="faster_self_attention"):
#        super(FasterSelfAttentionGN, self).__init__(name=name)
#        self.kv_dim = kv_dim
#        self.output_dim = output_dim
#        self.mlp = make_mlp_fn()
#        self.batch_size = batch_size
#        self.residual = residual
#        self.layer_norm = layer_norm
#        self.num_heads = num_heads
#        self.project_vs = []
#        self.project_qs = []
#        self.project_ks = []
#        initializers = {
#            'w': tf.contrib.layers.xavier_initializer(uniform=True),
#        }
#
#    def _build(self, graph):
#        initializers = {
#            'w': tf.contrib.layers.xavier_initializer(uniform=True),
#        }
#        project_q_mod = snt.Linear(
#            self.kv_dim, use_bias=False, initializers=initializers)
#        project_k_mod = snt.Linear(
#            self.kv_dim, use_bias=False, initializers=initializers)
#        project_v_mod = snt.Linear(
#            self.output_dim, use_bias=False, initializers=initializers)
#
#        project_q = tf.expand_dims(project_q_mod(graph.nodes), axis=1)
#        project_k = tf.expand_dims(project_k_mod(graph.nodes), axis=1)
#        project_v = tf.expand_dims(project_v_mod(graph.nodes), axis=1)
#
#        attn_module = gn.modules.SelfAttention()
#        attn_graph = attn_module(project_v, project_q, project_k, graph)
#        new_nodes = attn_graph.nodes[:, 0, :]
#        new_nodes = tf.concat([graph.nodes, new_nodes], axis=1)
#        new_nodes = self.mlp(new_nodes)
#        return graph.replace(nodes=new_nodes)
