Source code for UCTB.model_unit.DCRNN_CELL

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from tensorflow.contrib.rnn import RNNCell


class DCGRUCell(RNNCell):
    """Graph Convolution Gated Recurrent Unit cell."""

    def call(self, inputs, **kwargs):
        pass

    def compute_output_shape(self, input_shape):
        pass
    def __init__(self, num_units, input_dim, num_graphs, supports, max_diffusion_step, num_nodes,
                 num_proj=None, activation=tf.nn.tanh, reuse=None, use_gc_for_ru=True, name=None):
        """
        :param num_units: Number of hidden units per node.
        :param input_dim: Dimension of the input features of each node.
        :param num_graphs: Number of graphs; stored on the cell.
        :param supports: Stacked support (diffusion) matrices with shape
            (num_supports, num_nodes, num_nodes).
        :param max_diffusion_step: Maximum diffusion step K.
        :param num_nodes: Number of nodes in the graph.
        :param num_proj: If not None, the per-node output is projected to this dimension.
        :param activation: Activation function applied to the candidate state.
        :param reuse: Whether to reuse variables in an existing scope.
        :param use_gc_for_ru: Whether to use graph convolution to calculate the reset and update gates.
        :param name: Optional name appended to the variable scope used in the graph convolution.
        """
        super(DCGRUCell, self).__init__(_reuse=reuse)
        self._activation = activation
        self._num_nodes = num_nodes
        self._input_dim = input_dim
        self._num_graphs = num_graphs
        self._num_proj = num_proj
        self._num_units = num_units
        self._max_diffusion_step = max_diffusion_step
        self._use_gc_for_ru = use_gc_for_ru
        self._supports = supports
        self._num_diff_matrix = supports.get_shape()[0].value
        self._name = name

    @property
    def state_size(self):
        return self._num_nodes * self._num_units

    @property
    def output_size(self):
        output_size = self._num_nodes * self._num_units
        if self._num_proj is not None:
            output_size = self._num_nodes * self._num_proj
        return output_size

    def __call__(self, inputs, state, scope=None):
        """Gated recurrent unit (GRU) with graph convolution.

        :param inputs: (B, num_nodes * input_dim)
        :return:
            - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
            - New state: Either a single `2-D` tensor, or a tuple of tensors matching
              the arity and shapes of `state`.
        """
        with tf.variable_scope(scope or "dcgru_cell"):
            with tf.variable_scope("gates"):
                # Reset gate and update gate.
                output_size = 2 * self._num_units
                # We start with a bias of 1.0 to not reset and not update.
                if self._use_gc_for_ru:
                    fn = self._gconv
                else:
                    fn = self._fc
                value = tf.nn.sigmoid(fn(inputs, state, output_size, bias_start=1.0))
                value = tf.reshape(value, (-1, self._num_nodes, output_size))
                r, u = tf.split(value=value, num_or_size_splits=2, axis=-1)
                r = tf.reshape(r, (-1, self._num_nodes * self._num_units))
                u = tf.reshape(u, (-1, self._num_nodes * self._num_units))
            with tf.variable_scope("candidate"):
                c = self._gconv(inputs, r * state, self._num_units)
                if self._activation is not None:
                    c = self._activation(c)
            output = new_state = u * state + (1 - u) * c
            if self._num_proj is not None:
                with tf.variable_scope("projection"):
                    w = tf.get_variable('w', shape=(self._num_units, self._num_proj))
                    output = tf.reshape(new_state, shape=(-1, self._num_units))
                    output = tf.reshape(tf.matmul(output, w), shape=(-1, self.output_size))
        return output, new_state

    @staticmethod
    def _concat(x, x_):
        x_ = tf.expand_dims(x_, 0)
        return tf.concat([x, x_], axis=0)

    def _fc(self, inputs, state, output_size, bias_start=0.0):
        dtype = inputs.dtype
        batch_size = inputs.get_shape()[0].value
        inputs = tf.reshape(inputs, (batch_size * self._num_nodes, -1))
        state = tf.reshape(state, (batch_size * self._num_nodes, -1))
        inputs_and_state = tf.concat([inputs, state], axis=-1)
        input_size = inputs_and_state.get_shape()[-1].value
        weights = tf.get_variable(
            'weights', [input_size, output_size], dtype=dtype,
            initializer=tf.contrib.layers.xavier_initializer())
        value = tf.nn.sigmoid(tf.matmul(inputs_and_state, weights))
        biases = tf.get_variable("biases", [output_size], dtype=dtype,
                                 initializer=tf.constant_initializer(bias_start, dtype=dtype))
        value = tf.nn.bias_add(value, biases)
        return value

    def _gconv(self, inputs, state, output_size, bias_start=0.0):
        """Graph convolution between the input/state and the support matrices.

        :param inputs: (batch_size, num_nodes * input_dim)
        :param state: (batch_size, num_nodes * num_units)
        :param output_size: Output dimension per node.
        :param bias_start: Initial value for the bias.
        :return: A tensor with shape (batch_size, num_nodes * output_size).
        """
        # Reshape input and state to (batch_size, num_nodes, input_dim/state_dim).
        last_dim = inputs.get_shape()[-1].value
        inputs = tf.reshape(inputs, (-1, self._num_nodes, int(last_dim / self._num_nodes)))
        state = tf.reshape(state, (-1, self._num_nodes, self._num_units))
        inputs_and_state = tf.concat([inputs, state], axis=2)
        input_size = inputs_and_state.get_shape()[2].value
        dtype = inputs.dtype

        x = inputs_and_state
        x0 = tf.transpose(x, perm=[1, 2, 0])  # (num_nodes, total_arg_size, batch_size)
        x0 = tf.reshape(x0, shape=[self._num_nodes, -1])
        x = tf.expand_dims(x0, axis=0)

        scope = tf.get_variable_scope()
        with tf.variable_scope(scope.name + (self.name or ''), reuse=False):
            if self._max_diffusion_step == 0:
                pass
            else:
                for index in range(self._num_diff_matrix):
                    x1 = tf.matmul(self._supports[index], x0)
                    x = self._concat(x, x1)
                    for k in range(2, self._max_diffusion_step + 1):
                        # Chebyshev-style recurrence over the diffusion steps.
                        x2 = 2 * tf.matmul(self._supports[index], x1) - x0
                        x = self._concat(x, x2)
                        x1, x0 = x2, x1

            num_matrices = self._num_diff_matrix * self._max_diffusion_step + 1  # Adds one for x itself.
            x = tf.reshape(x, shape=[num_matrices, self._num_nodes, input_size, -1])
            x = tf.transpose(x, perm=[3, 1, 2, 0])  # (batch_size, num_nodes, input_size, order)
            x = tf.reshape(x, shape=[-1, input_size * num_matrices])

            weights = tf.get_variable(
                'weights', [input_size * num_matrices, output_size], dtype=dtype,
                initializer=tf.contrib.layers.xavier_initializer())
            x = tf.matmul(x, weights)  # (batch_size * self._num_nodes, output_size)

            biases = tf.get_variable("biases", [output_size], dtype=dtype,
                                     initializer=tf.constant_initializer(bias_start, dtype=dtype))
            x = tf.nn.bias_add(x, biases)
        # Reshape the result back to 2D: (batch_size, num_nodes, output_size) -> (batch_size, num_nodes * output_size).
        return tf.reshape(x, [-1, self._num_nodes * output_size])
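
The cell expects ``supports`` as a pre-stacked tensor of shape (num_supports, num_nodes, num_nodes) and iterates over its first dimension inside ``_gconv``. As a rough illustration of how such a tensor could be prepared, the sketch below stacks forward and backward random-walk transition matrices in the spirit of the DCRNN formulation; the helper ``random_walk_supports`` and the toy adjacency matrix are assumptions made for this example only, not UCTB's actual graph preprocessing.

# Illustrative sketch; not part of UCTB.model_unit.DCRNN_CELL.
def random_walk_supports(adj):
    """Stack forward/backward random-walk transition matrices, shape (2, N, N)."""
    adj = np.asarray(adj, dtype=np.float32)
    d_out = adj.sum(axis=1, keepdims=True)        # out-degree of every node
    d_in = adj.sum(axis=0, keepdims=True)         # in-degree of every node
    forward = adj / np.maximum(d_out, 1e-8)       # D_O^{-1} W
    backward = adj.T / np.maximum(d_in.T, 1e-8)   # D_I^{-1} W^T
    return np.stack([forward, backward], axis=0)

# Toy directed chain graph with 10 nodes plus self-loops.
supports = tf.constant(random_walk_supports(np.eye(10) + np.eye(10, k=1)))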
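
A minimal single-step usage sketch, assuming the ``supports`` tensor from the previous snippet; the batch size, feature sizes, and ``num_graphs=1`` are illustrative values, and unrolling the cell over a sequence is left to the surrounding model code.

# Illustrative sketch; sizes are placeholders for this example.
num_nodes, input_dim, num_units, batch_size = 10, 2, 64, 32

cell = DCGRUCell(num_units=num_units,
                 input_dim=input_dim,
                 num_graphs=1,        # stored by the cell; the diffusion loop iterates over supports' first dim
                 supports=supports,
                 max_diffusion_step=2,
                 num_nodes=num_nodes,
                 num_proj=1)

inputs = tf.placeholder(tf.float32, shape=[batch_size, num_nodes * input_dim])
state = tf.zeros([batch_size, cell.state_size])
# One recurrent step: output has shape (batch_size, num_nodes * num_proj),
# new_state has shape (batch_size, num_nodes * num_units).
output, new_state = cell(inputs, state)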