import tensorflow as tf
from ..model_unit import BaseModel
class ST_ResNet(BaseModel):
    """ST-ResNet is a deep-learning model with an end-to-end structure
    based on unique properties of spatio-temporal data making use of convolution and residual units.

    Every enabled history input (closeness / period / trend) is processed by its own
    convolutional branch made of residual units. The branch outputs are fused with one
    learnable scalar weight per branch, optionally concatenated with a dense embedding of
    the external features, and mapped to a single sigmoid-activated value per grid cell.

    Reference: `Deep Spatio-Temporal Residual Networks for Citywide Crowd Flows Prediction (Junbo Zhang et al., 2016)
    <https://arxiv.org/pdf/1610.00081.pdf>`_.

    Args:
        width (int): The width of grid data.
        height (int): The height of grid data.
        external_dim (int): Number of dimensions of external data.
            ``None`` or a non-positive value disables the external input.
        closeness_len (int): The length of closeness data history. The former consecutive ``closeness_len``
            time slots of data will be used as closeness history.
        period_len (int): The length of period data history. The data of exact same time slots in former
            consecutive ``period_len`` days will be used as period history.
        trend_len (int): The length of trend data history. The data of exact same time slots in former
            consecutive ``trend_len`` weeks (every seven days) will be used as trend history.
        num_residual_unit (int): Number of residual units. Default: 4
        kernel_size (int): Kernel size in Convolutional neural networks. Default: 3
        lr (float): Learning rate. Default: 5e-5
        model_dir (str): The directory to store model files. Default: 'model_dir'
        code_version (str): Current version of this model code. Default: 'QuickStart'
        conv_filters (int): The number of filters in the convolution. Default: 64
        gpu_device (str): To specify the GPU to use. Default: '0'
    """

    def __init__(self,
                 width,
                 height,
                 external_dim,
                 closeness_len,
                 period_len,
                 trend_len,
                 num_residual_unit=4,
                 kernel_size=3,
                 lr=5e-5,
                 model_dir='model_dir',
                 code_version='QuickStart',
                 conv_filters=64,
                 gpu_device='0'):
        super(ST_ResNet, self).__init__(code_version=code_version, model_dir=model_dir, gpu_device=gpu_device)

        # Grid geometry.
        self._width = width
        self._height = height

        # Lengths of the three history inputs and size of the external input;
        # ``None`` or a non-positive value switches the corresponding input off.
        self._closeness_len = closeness_len
        self._period_len = period_len
        self._trend_len = trend_len
        self._external_dim = external_dim

        # Network and optimisation hyper-parameters.
        self._conv_filters = conv_filters
        self._kernel_size = kernel_size
        self._num_residual_unit = num_residual_unit
        self._lr = lr

    @staticmethod
    def _is_positive(value):
        """Return True when ``value`` is set (not ``None``) and greater than zero."""
        return value is not None and value > 0

    def _conv(self, inputs, activation=None):
        """Apply a 'SAME'-padded square convolution with the configured filters and kernel size."""
        return tf.layers.conv2d(inputs, filters=self._conv_filters,
                                kernel_size=[self._kernel_size, self._kernel_size],
                                padding='SAME', activation=activation)

    def _residual_unit(self, x):
        """Apply (ReLU -> conv) twice and add the unit's input back as a shortcut."""
        hidden = self._conv(tf.nn.relu(x))
        hidden = self._conv(tf.nn.relu(hidden))
        return hidden + x

    def _branch(self, history):
        """Run one history input through conv -> residual units -> ReLU -> conv."""
        hidden = self._conv(history, activation=tf.nn.relu)
        for _ in range(self._num_residual_unit):
            hidden = self._residual_unit(hidden)
        return self._conv(tf.nn.relu(hidden))

    def build(self):
        """Construct the TensorFlow graph of the model.

        Registers the names of the created placeholders in ``self._input``, the names of the
        prediction and loss tensors in ``self._output`` and the name of the training op in
        ``self._op``, then calls the base-class ``build``.

        Raises:
            ValueError: If none of ``closeness_len``, ``period_len`` and ``trend_len`` is a
                positive number, i.e. the model would have no history input at all.
        """
        with self._graph.as_default():
            # NOTE: the order in which ops are created below is kept stable on purpose.
            # Layers without an explicit name get auto-numbered names, so reordering
            # them would rename the variables.

            # One placeholder per enabled history input: (length, op name, feed key).
            history_specs = (
                (self._closeness_len, 'c', 'closeness_feature'),
                (self._period_len, 'p', 'period_feature'),
                (self._trend_len, 't', 'trend_feature'),
            )
            history_inputs = []
            for length, op_name, feed_key in history_specs:
                if self._is_positive(length):
                    placeholder = tf.placeholder(tf.float32, [None, self._height, self._width, length],
                                                 name=op_name)
                    self._input[feed_key] = placeholder.name
                    history_inputs.append(placeholder)

            if not history_inputs:
                # Without this check the failure only shows up later as a shape error
                # raised from inside TensorFlow.
                raise ValueError('ST_ResNet needs at least one of closeness_len, period_len '
                                 'or trend_len to be a positive integer.')

            target = tf.placeholder(tf.float32, [None, self._height, self._width, 1], name='target')
            self._input['target'] = target.name

            # Each history input gets its own, independently parameterised branch.
            branch_outputs = [self._branch(history) for history in history_inputs]

            if len(branch_outputs) == 1:
                x = branch_outputs[0]
            else:
                # Fusion: weighted sum of the branches with one learnable scalar per branch.
                fusion_weight = tf.Variable(tf.random_normal([len(branch_outputs), ]))
                weighted = [fusion_weight[i] * output for i, output in enumerate(branch_outputs)]
                x = tf.reduce_sum(weighted, axis=0)

            # external dims
            if self._is_positive(self._external_dim):
                external_units = 10
                external_input = tf.placeholder(tf.float32, [None, self._external_dim])
                self._input['external_feature'] = external_input.name
                external_dense = tf.layers.dense(inputs=external_input, units=external_units)
                # Broadcast the per-sample embedding to every grid cell so that it can be
                # concatenated with the feature map along the channel axis.
                external_dense = tf.tile(tf.reshape(external_dense, [-1, 1, 1, external_units]),
                                         [1, self._height, self._width, 1])
                x = tf.concat([x, external_dense], axis=-1)

            # The dense layer acts on the last (channel) axis, giving one value per grid cell.
            # The sigmoid bounds predictions to (0, 1); targets are presumably normalised
            # to that range by the caller -- TODO confirm.
            x = tf.layers.dense(x, units=1, name='prediction', activation=tf.nn.sigmoid)

            # Root-mean-square error over all samples and grid cells.
            loss = tf.sqrt(tf.reduce_mean(tf.square(x - target)), name='loss')
            train_op = tf.train.AdamOptimizer(self._lr).minimize(loss)

            self._output['prediction'] = x.name
            self._output['loss'] = loss.name
            self._op['train_op'] = train_op.name

            # Hand over to the base-class build step (defined in BaseModel, not shown here).
            super(ST_ResNet, self).build()

    def _get_feed_dict(self, closeness_feature=None, period_feature=None, trend_feature=None,
                       target=None, external_feature=None):
        """Get the feed dict for the TensorFlow model.

        Users may modify this method according to the format of input.

        ``target`` is included only when it is given. Every input that is enabled by the
        model configuration is always included, even when the matching argument is ``None``;
        arguments for disabled inputs are ignored.

        Args:
            closeness_feature (np.ndarray or ``None``): The closeness history data.
                If type is np.ndarray, its shape is [time_slot_num, height, width, closeness_len].
            period_feature (np.ndarray or ``None``): The period history data.
                If type is np.ndarray, its shape is [time_slot_num, height, width, period_len].
            trend_feature (np.ndarray or ``None``): The trend history data.
                If type is np.ndarray, its shape is [time_slot_num, height, width, trend_len].
            target (np.ndarray or ``None``): The target value data.
                If type is np.ndarray, its shape is [time_slot_num, height, width, 1].
            external_feature (np.ndarray or ``None``): The external feature data.
                If type is np.ndarray, its shape is [time_slot_num, feature_num].

        Returns:
            dict: Maps the input keys registered by ``build`` (``'target'``,
            ``'external_feature'``, ``'closeness_feature'``, ``'period_feature'``,
            ``'trend_feature'``) to the supplied values.
        """
        feed_dict = {}
        if target is not None:
            feed_dict['target'] = target

        # (configured size, feed key, supplied value) for every optional model input.
        optional_inputs = (
            (self._external_dim, 'external_feature', external_feature),
            (self._closeness_len, 'closeness_feature', closeness_feature),
            (self._period_len, 'period_feature', period_feature),
            (self._trend_len, 'trend_feature', trend_feature),
        )
        for size, feed_key, value in optional_inputs:
            if self._is_positive(size):
                feed_dict[feed_key] = value
        return feed_dict