3. Quick start

3.1. Quick start with STMeta

from UCTB.dataset import NodeTrafficLoader
from UCTB.model import STMeta
from UCTB.evaluation import metric
from UCTB.preprocess.GraphGenerator import GraphGenerator
# Step 1: configure the data loader (NYC bike data, with the closeness /
# period / trend sequence lengths that are later handed to the model).
loader_config = dict(dataset='Bike', city='NYC', graph='Correlation',
                     closeness_len=6, period_len=7, trend_len=4, normalize=True)
data_loader = NodeTrafficLoader(**loader_config)

# Step 2: build the graph from the loaded data and keep a handle on the
# matrix stack (graph_obj.LM) that is passed to the model as laplace_matrix.
graph_obj = GraphGenerator(graph='Correlation', data_loader=data_loader)
laplace_matrix = graph_obj.LM

# Step 3: create the model; all sizes are read back from the loader and the
# graph object rather than being repeated by hand.
STMeta_Obj = STMeta(
    closeness_len=data_loader.closeness_len,
    period_len=data_loader.period_len,
    trend_len=data_loader.trend_len,
    num_node=data_loader.station_number,
    num_graph=laplace_matrix.shape[0],
    external_dim=data_loader.external_dim,
)

# Step 4: build the tf-graph; this is done before fit/predict are called.
STMeta_Obj.build()

# Step 5: train on the training split.
train_inputs = {
    'closeness_feature': data_loader.train_closeness,
    'period_feature': data_loader.train_period,
    'trend_feature': data_loader.train_trend,
    'laplace_matrix': laplace_matrix,
    'target': data_loader.train_y,
    'external_feature': data_loader.train_ef,
    'sequence_length': data_loader.train_sequence_len,
}
STMeta_Obj.fit(**train_inputs)

# Step 6: predict on the test split, requesting only the 'prediction' output.
test_inputs = {
    'closeness_feature': data_loader.test_closeness,
    'period_feature': data_loader.test_period,
    'trend_feature': data_loader.test_trend,
    'laplace_matrix': laplace_matrix,
    'target': data_loader.test_y,
    'external_feature': data_loader.test_ef,
    'output_names': ['prediction'],
    'sequence_length': data_loader.test_sequence_len,
}
prediction = STMeta_Obj.predict(**test_inputs)

# Step 7: evaluate. Both the prediction and the ground truth are passed
# through the loader's min-max denormalizer before the RMSE is computed.
denormalize = data_loader.normalizer.min_max_denormal
predicted_values = denormalize(prediction['prediction'])
observed_values = denormalize(data_loader.test_y)
test_rmse = metric.rmse(prediction=predicted_values, target=observed_values, threshold=0)
print('Test result', test_rmse)

3.2. Quick start with other models