import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops, linalg_ops, math_ops


def _generate_dropout_mask(ones, rate, training=None, count=1):
    """Build `count` dropout masks shaped like `ones`; identity outside training."""
    def dropped_inputs():
        return K.dropout(ones, rate)

    if count > 1:
        return [
            K.in_train_phase(dropped_inputs, ones, training=training)
            for _ in range(count)
        ]
    return K.in_train_phase(dropped_inputs, ones, training=training)
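
# Illustrative sketch (not part of the original module): `K.dropout` applies
# inverted dropout, so kept entries are scaled by 1 / (1 - rate) during
# training, while `K.in_train_phase` falls back to the untouched `ones`
# tensor at inference.
#
#     masks = _generate_dropout_mask(K.ones_like(x), rate=0.2,
#                                    training=True, count=4)
#     x_dropped = x * masks[0]  # same usage as GCLSTMCell.call below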


class GCLSTMCell(tf.keras.layers.LSTMCell):
    """LSTM cell whose input and recurrent transforms are preceded by a
    Chebyshev graph convolution over `laplacian_matrix`."""

    def __init__(self, units, num_nodes, laplacian_matrix, gcn_k=1, gcn_l=1, **kwargs):
        super().__init__(units, **kwargs)
        self._units = units
        self._num_node = num_nodes
        self._gcn_k = gcn_k
        self._gcn_l = gcn_l
        self._laplacian_matrix = laplacian_matrix

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        super(GCLSTMCell, self).build(input_shape)
        input_dim = input_shape[-1]
        # Both kernels are widened by a factor of (gcn_k + 1): each dense
        # transform consumes the concatenation of the K + 1 Chebyshev
        # graph-convolution outputs, and the factor of 4 covers the LSTM
        # gates (input, forget, cell, output).
        self.kernel = self.add_weight(
            shape=(input_dim * (self._gcn_k + 1), self.units * 4),
            name='kernel',
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units * (self._gcn_k + 1), self.units * 4),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)

    def kth_cheby_ploy(self, k, tk1=None, tk2=None):
        # Chebyshev polynomial of the (scaled) Laplacian, via the recurrence
        # T_0(L) = I, T_1(L) = L, T_k(L) = 2 L T_{k-1}(L) - T_{k-2}(L).
        if k == 0:
            return linalg_ops.eye(self._num_node, dtype=dtypes.float32)
        elif k == 1:
            return self._laplacian_matrix
        elif k > 1:
            return math_ops.matmul(2 * self._laplacian_matrix, tk1) - tk2
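
    # Illustrative sketch (not part of the original module): materialising
    # the full polynomial basis for a given K; callers thread tk1 = T_{k-1}
    # and tk2 = T_{k-2} into each step, exactly as `call` does below.
    #
    #     t = []
    #     for k in range(K + 1):
    #         t.append(cell.kth_cheby_ploy(k,
    #                                      tk1=t[-1] if k >= 1 else None,
    #                                      tk2=t[-2] if k >= 2 else None))
    #     # t == [T_0(L), ..., T_K(L)], each of shape [num_nodes, num_nodes]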

    def call(self, inputs, states, training=None):
        # One dropout mask per LSTM gate (hence count=4), created lazily on
        # the first call and reused afterwards.
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                array_ops.ones_like(inputs),
                self.dropout,
                training=training,
                count=4)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                array_ops.ones_like(states[0]),
                self.recurrent_dropout,
                training=training,
                count=4)

        # Static feature size (TF 1.x TensorShape API).
        input_dim = inputs.get_shape()[-1].value
        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        if 0. < self.dropout < 1.:
            inputs *= dp_mask[0]
        if 0. < self.recurrent_dropout < 1.:
            h_tm1 *= rec_dp_mask[0]

        # inputs has shape [batch * num_nodes, input_dim] and h_tm1 has shape
        # [batch * num_nodes, units]; regroup both to [num_nodes, batch * dim]
        # so the Chebyshev matrices can mix features across the graph.
        inputs_before_gcn = tf.reshape(
            tf.transpose(tf.reshape(inputs, [-1, self._num_node, input_dim]),
                         [1, 0, 2]),
            [self._num_node, -1])
        h_tm1_before_gcn = tf.reshape(
            tf.transpose(tf.reshape(h_tm1, [-1, self._num_node, self._units]),
                         [1, 0, 2]),
            [self._num_node, -1])

        # Accumulate the Chebyshev basis T_0..T_K and apply each polynomial
        # to both the inputs and the previous hidden state.
        t = []
        inputs_after_gcn = list()
        h_tm1_after_gcn = list()
        for i in range(0, self._gcn_k + 1):
            t.append(self.kth_cheby_ploy(k=i,
                                         tk1=None if i < 1 else t[i - 1],
                                         tk2=None if i < 2 else t[i - 2]))
            inputs_after_gcn.append(tf.matmul(t[-1], inputs_before_gcn))
            h_tm1_after_gcn.append(tf.matmul(t[-1], h_tm1_before_gcn))

        # The Python lists are implicitly stacked along a new leading axis,
        # then flattened back to [batch * num_nodes, (K + 1) * dim] so they
        # line up with the widened kernels built in `build`.
        inputs_after_gcn = tf.reshape(inputs_after_gcn,
                                      [self._gcn_k + 1, self._num_node, -1, input_dim])
        h_tm1_after_gcn = tf.reshape(h_tm1_after_gcn,
                                     [self._gcn_k + 1, self._num_node, -1, self._units])
        inputs_after_gcn = tf.reshape(tf.transpose(inputs_after_gcn, [2, 1, 0, 3]),
                                      [-1, (self._gcn_k + 1) * input_dim])
        h_tm1_after_gcn = tf.reshape(tf.transpose(h_tm1_after_gcn, [2, 1, 0, 3]),
                                     [-1, (self._gcn_k + 1) * self.units])

        # Standard fused LSTM gate computation on the graph-convolved features.
        z = K.dot(inputs_after_gcn, self.kernel)
        z += K.dot(h_tm1_after_gcn, self.recurrent_kernel)
        if self.use_bias:
            z = K.bias_add(z, self.bias)

        # Split the pre-activations into the four gate slices (i, f, c, o).
        z0 = z[:, :self.units]
        z1 = z[:, self.units:2 * self.units]
        z2 = z[:, 2 * self.units:3 * self.units]
        z3 = z[:, 3 * self.units:]
        z = (z0, z1, z2, z3)

        c, o = self._compute_carry_and_output_fused(z, c_tm1)
        h = o * self.activation(c)
        return h, [h, c]
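

# A minimal usage sketch (an assumption for illustration, not part of the
# original module): the cell folds the node axis into the batch axis, so
# sequences are fed as [batch * num_nodes, time, input_dim] with a
# [num_nodes, num_nodes] scaled Laplacian (here an identity placeholder).
# This targets a TF 1.x-era Keras, where `LSTMCell` still exposes the
# `_dropout_mask` attributes read in `call` above.
if __name__ == '__main__':
    num_nodes, input_dim, units, batch, steps = 10, 8, 16, 4, 12
    laplacian = tf.eye(num_nodes, dtype=tf.float32)  # placeholder Laplacian
    cell = GCLSTMCell(units=units, num_nodes=num_nodes,
                      laplacian_matrix=laplacian, gcn_k=2, gcn_l=1)
    rnn = tf.keras.layers.RNN(cell)
    x = tf.zeros([batch * num_nodes, steps, input_dim])
    h = rnn(x)  # final hidden state: [batch * num_nodes, units]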