Source code for UCTB.model.STMeta
import tensorflow as tf
from ..model_unit import BaseModel
from ..model_unit import GAL, GCL
from ..model_unit import DCGRUCell
from ..model_unit import GCLSTMCell
class STMeta(BaseModel):
"""
Args:
num_node(int): Number of nodes in the graph, e.g. number of stations in NYC-Bike dataset.
        external_dim(int): Dimension of the external features, e.g. temperature and wind speed are two dimensions.
closeness_len(int): The length of closeness data history. The former consecutive ``closeness_len`` time slots
of data will be used as closeness history.
period_len(int): The length of period data history. The data of exact same time slots in former consecutive
``period_len`` days will be used as period history.
trend_len(int): The length of trend data history. The data of exact same time slots in former consecutive
``trend_len`` weeks (every seven days) will be used as trend history.
num_graph(int): Number of graphs used in STMeta.
gcn_k(int): The highest order of Chebyshev Polynomial approximation in GCN.
gcn_layers(int): Number of GCN layers.
        gclstm_layers(int): Number of spatial-temporal RNN layers; it applies to all ``st_method`` modes of STMeta,
            such as GCLSTM and DCRNN.
num_hidden_units(int): Number of hidden units of RNN.
num_dense_units(int): Number of dense units.
        graph_merge_gal_units(int): Number of units in the GAL used to merge features from different graphs.
            Only works when ``graph_merge='gal'``.
        graph_merge_gal_num_heads(int): Number of heads in the GAL used to merge features from different graphs.
            Only works when ``graph_merge='gal'``.
        temporal_merge_gal_units(int): Number of units in the GAL used to merge different temporal features.
            Only works when ``temporal_merge='gal'``.
        temporal_merge_gal_num_heads(int): Number of heads in the GAL used to merge different temporal features.
            Only works when ``temporal_merge='gal'``.
        st_method(str): must be one of ['GCLSTM', 'DCRNN', 'GRU', 'LSTM'], which selects the
            spatial-temporal modeling method.
            'GCLSTM': GCN for modeling the spatial feature, LSTM for modeling the temporal feature.
            'DCRNN': Diffusion convolution for modeling the spatial feature, GRU for modeling the temporal feature.
            'GRU': Ignores the spatial feature and models the temporal feature using GRU.
            'LSTM': Ignores the spatial feature and models the temporal feature using LSTM.
        temporal_merge(str): must be one of ['gal', 'concat'], which selects the temporal merging method.
            'gal': merge using GAL.
            'concat': merge by concatenation and a dense layer.
        graph_merge(str): must be one of ['gal', 'concat'], which selects the graph merging method.
            'gal': merge using GAL.
            'concat': merge by concatenation and a dense layer.
        output_activation(function): Activation function of the output layer, e.g. ``tf.nn.tanh``.
        lr(float): Learning rate. Default: 1e-4.
        code_version(str): Current version of this model code, used as the filename when saving the model.
model_dir(str): The directory to store model files. Default:'model_dir'.
gpu_device(str): To specify the GPU to use. Default: '0'.
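
    Example:
        A minimal instantiation sketch; the argument values below are illustrative
        only (not defaults) and should come from your own dataset::

            model = STMeta(num_node=256, external_dim=5,
                           closeness_len=6, period_len=7, trend_len=4,
                           num_graph=2, st_method='GCLSTM')
            model.build()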
"""
def __init__(self,
num_node,
external_dim,
closeness_len,
period_len,
trend_len,
# gcn parameters
num_graph=1,
gcn_k=1,
gcn_layers=1,
gclstm_layers=1,
                 # RNN hidden units
                 num_hidden_units=64,
                 # dense units
                 num_dense_units=32,
# merge parameters
graph_merge_gal_units=32,
graph_merge_gal_num_heads=2,
temporal_merge_gal_units=64,
temporal_merge_gal_num_heads=2,
# network structure parameters
                 st_method='GCLSTM',    # one of: 'GCLSTM', 'DCRNN', 'GRU', 'LSTM'
                 temporal_merge='gal',  # one of: 'gal', 'concat'
                 graph_merge='gal',     # one of: 'gal', 'concat'
output_activation=tf.nn.sigmoid,
lr=1e-4,
code_version='STMeta-QuickStart',
model_dir='model_dir',
gpu_device='0', **kwargs):
super(STMeta, self).__init__(code_version=code_version, model_dir=model_dir, gpu_device=gpu_device)
self._num_node = num_node
self._gcn_k = gcn_k
self._gcn_layer = gcn_layers
self._graph_merge_gal_units = graph_merge_gal_units
self._graph_merge_gal_num_heads = graph_merge_gal_num_heads
self._temporal_merge_gal_units = temporal_merge_gal_units
self._temporal_merge_gal_num_heads = temporal_merge_gal_num_heads
self._gclstm_layers = gclstm_layers
self._num_graph = num_graph
self._external_dim = external_dim
self._output_activation = output_activation
self._st_method = st_method
self._temporal_merge = temporal_merge
self._graph_merge = graph_merge
self._closeness_len = int(closeness_len)
self._period_len = int(period_len)
self._trend_len = int(trend_len)
self._num_hidden_unit = num_hidden_units
self._num_dense_units = num_dense_units
self._lr = lr
    def build(self, init_vars=True, max_to_keep=5):
with self._graph.as_default():
temporal_features = []
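            # Each enabled temporal view (closeness / period / trend) gets a placeholder
            # of shape [batch, num_node, series_length, 1]; temporal_features collects
            # [series_length, tensor, name] triples for the per-graph loop below.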
if self._closeness_len is not None and self._closeness_len > 0:
closeness_feature = tf.placeholder(tf.float32, [None, None, self._closeness_len, 1],
name='closeness_feature')
self._input['closeness_feature'] = closeness_feature.name
temporal_features.append([self._closeness_len, closeness_feature, 'closeness_feature'])
if self._period_len is not None and self._period_len > 0:
period_feature = tf.placeholder(tf.float32, [None, None, self._period_len, 1],
name='period_feature')
self._input['period_feature'] = period_feature.name
temporal_features.append([self._period_len, period_feature, 'period_feature'])
if self._trend_len is not None and self._trend_len > 0:
trend_feature = tf.placeholder(tf.float32, [None, None, self._trend_len, 1],
name='trend_feature')
self._input['trend_feature'] = trend_feature.name
temporal_features.append([self._trend_len, trend_feature, 'trend_feature'])
if len(temporal_features) > 0:
target = tf.placeholder(tf.float32, [None, None, 1], name='target')
laplace_matrix = tf.placeholder(tf.float32, [self._num_graph, None, None], name='laplace_matrix')
self._input['target'] = target.name
self._input['laplace_matrix'] = laplace_matrix.name
else:
raise ValueError('closeness_len, period_len, trend_len cannot all be zero')
graph_outputs_list = []
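            # Run the chosen spatial-temporal method once per graph; the per-graph
            # outputs are merged further below (via GAL or concatenation).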
for graph_index in range(self._num_graph):
if self._st_method in ['GCLSTM', 'DCRNN', 'GRU', 'LSTM']:
outputs_temporal = []
for time_step, target_tensor, given_name in temporal_features:
if self._st_method == 'GCLSTM':
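                            # Stack gclstm_layers GCLSTM cells, each coupling a GCN over
                            # this graph's Laplacian with an LSTM over the time axis.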
multi_layer_cell = tf.keras.layers.StackedRNNCells(
[GCLSTMCell(units=self._num_hidden_unit, num_nodes=self._num_node,
laplacian_matrix=laplace_matrix[graph_index],
gcn_k=self._gcn_k, gcn_l=self._gcn_layer)
for _ in range(self._gclstm_layers)])
outputs = tf.keras.layers.RNN(multi_layer_cell)(tf.reshape(target_tensor, [-1, time_step, 1]))
st_outputs = tf.reshape(outputs, [-1, 1, self._num_hidden_unit])
elif self._st_method == 'DCRNN':
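                            # DCGRU replaces the matrix multiplications inside a GRU with
                            # diffusion convolutions over the graph's diffusion matrix.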
cell = DCGRUCell(self._num_hidden_unit, 1, self._num_graph,
# laplace_matrix will be diffusion_matrix when self._st_method == 'DCRNN'
laplace_matrix,
max_diffusion_step=self._gcn_k,
num_nodes=self._num_node, name=str(graph_index) + given_name)
encoding_cells = [cell] * self._gclstm_layers
encoding_cells = tf.contrib.rnn.MultiRNNCell(encoding_cells, state_is_tuple=True)
inputs_unstack = tf.unstack(tf.reshape(target_tensor, [-1, self._num_node, time_step]),
axis=-1)
outputs, _ = \
tf.contrib.rnn.static_rnn(encoding_cells, inputs_unstack, dtype=tf.float32)
st_outputs = tf.reshape(outputs[-1], [-1, 1, self._num_hidden_unit])
                        elif self._st_method == 'GRU':
                            # Build an independent GRUCell per layer, so that stacked
                            # layers do not share one cell object (and its weights).
                            multi_layer_gru = tf.keras.layers.StackedRNNCells(
                                [tf.keras.layers.GRUCell(units=self._num_hidden_unit)
                                 for _ in range(self._gclstm_layers)])
                            outputs = tf.keras.layers.RNN(multi_layer_gru)(
                                tf.reshape(target_tensor, [-1, time_step, 1]))
                            st_outputs = tf.reshape(outputs, [-1, 1, self._num_hidden_unit])
                        elif self._st_method == 'LSTM':
                            # Same per-layer construction with LSTM cells.
                            multi_layer_lstm = tf.keras.layers.StackedRNNCells(
                                [tf.keras.layers.LSTMCell(units=self._num_hidden_unit)
                                 for _ in range(self._gclstm_layers)])
                            outputs = tf.keras.layers.RNN(multi_layer_lstm)(
                                tf.reshape(target_tensor, [-1, time_step, 1]))
                            st_outputs = tf.reshape(outputs, [-1, 1, self._num_hidden_unit])
outputs_temporal.append(st_outputs)
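                    # Merge the closeness/period/trend outputs, either with attention
                    # (GAL) or by concatenation along the feature axis.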
if self._temporal_merge == 'concat':
graph_outputs_list.append(tf.concat(outputs_temporal, axis=-1))
elif self._temporal_merge == 'gal':
_, gal_output = GAL.add_ga_layer_matrix(inputs=tf.concat(outputs_temporal, axis=-2),
units=self._temporal_merge_gal_units,
num_head=self._temporal_merge_gal_num_heads)
graph_outputs_list.append(tf.reduce_mean(gal_output, axis=-2, keepdims=True))
if self._num_graph > 1:
if self._graph_merge == 'gal':
# (graph, inputs_name, units, num_head, activation=tf.nn.leaky_relu)
_, gal_output = GAL.add_ga_layer_matrix(inputs=tf.concat(graph_outputs_list, axis=-2),
units=self._graph_merge_gal_units,
num_head=self._graph_merge_gal_num_heads)
dense_inputs = tf.reduce_mean(gal_output, axis=-2, keepdims=True)
elif self._graph_merge == 'concat':
dense_inputs = tf.concat(graph_outputs_list, axis=-1)
else:
dense_inputs = graph_outputs_list[-1]
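            # Reshape the merged features to [batch, num_node, 1, channels] and
            # batch-normalize them before the dense prediction head.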
dense_inputs = tf.reshape(dense_inputs, [-1, self._num_node, 1, dense_inputs.get_shape()[-1].value])
            dense_inputs = tf.keras.layers.BatchNormalization(axis=-1, name='feature_map')(dense_inputs)
# external dims
if self._external_dim is not None and self._external_dim > 0:
                external_input = tf.placeholder(tf.float32, [None, self._external_dim],
                                                name='external_feature')
self._input['external_feature'] = external_input.name
external_dense = tf.keras.layers.Dense(units=10)(external_input)
external_dense = tf.tile(tf.reshape(external_dense, [-1, 1, 1, 10]),
[1, tf.shape(dense_inputs)[1], tf.shape(dense_inputs)[2], 1])
dense_inputs = tf.concat([dense_inputs, external_dense], axis=-1)
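            # Prediction head: two tanh dense layers plus a single-unit output layer,
            # all with L2 regularization on kernels and biases.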
dense_output0 = tf.keras.layers.Dense(units=self._num_dense_units,
activation=tf.nn.tanh,
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
kernel_regularizer=tf.keras.regularizers.l2(0.01),
bias_regularizer=tf.keras.regularizers.l2(0.01)
)(dense_inputs)
dense_output1 = tf.keras.layers.Dense(units=self._num_dense_units,
activation=tf.nn.tanh,
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
kernel_regularizer=tf.keras.regularizers.l2(0.01),
bias_regularizer=tf.keras.regularizers.l2(0.01)
)(dense_output0)
pre_output = tf.keras.layers.Dense(units=1,
activation=tf.nn.tanh,
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
kernel_regularizer=tf.keras.regularizers.l2(0.01),
bias_regularizer=tf.keras.regularizers.l2(0.01)
)(dense_output1)
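            # The training loss is the RMSE between targets and predictions,
            # minimized with Adam.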
prediction = tf.reshape(pre_output, [-1, self._num_node, 1], name='prediction')
loss_pre = tf.sqrt(tf.reduce_mean(tf.square(target - prediction)), name='loss')
train_op = tf.train.AdamOptimizer(self._lr).minimize(loss_pre, name='train_op')
# record output
self._output['prediction'] = prediction.name
self._output['loss'] = loss_pre.name
# record train operation
self._op['train_op'] = train_op.name
super(STMeta, self).build(init_vars, max_to_keep)
    # Define the '_get_feed_dict' function, which maps the inputs onto the model's placeholders.
def _get_feed_dict(self,
laplace_matrix,
closeness_feature=None,
period_feature=None,
trend_feature=None,
target=None,
external_feature=None):
feed_dict = {
'laplace_matrix': laplace_matrix,
}
if target is not None:
feed_dict['target'] = target
if self._external_dim is not None and self._external_dim > 0:
feed_dict['external_feature'] = external_feature
if self._closeness_len is not None and self._closeness_len > 0:
feed_dict['closeness_feature'] = closeness_feature
if self._period_len is not None and self._period_len > 0:
feed_dict['period_feature'] = period_feature
if self._trend_len is not None and self._trend_len > 0:
feed_dict['trend_feature'] = trend_feature
return feed_dict
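
# A hypothetical call sketch for _get_feed_dict; the variable names and shapes below
# are illustrative only (num_graph graphs, N nodes) and are not part of the API:
#
#   feed = model._get_feed_dict(
#       laplace_matrix=lm,        # shape (num_graph, N, N)
#       closeness_feature=c,      # shape (batch, N, closeness_len, 1)
#       target=y)                 # shape (batch, N, 1)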