Source code for UCTB.model_unit.BaseModel

import os
import numpy as np
import shutil
import tensorflow as tf

from tensorboard.backend.event_processing import event_accumulator

from ..train.MiniBatchTrain import MiniBatchFeedDict
from ..preprocess.preprocessor import SplitData
from ..train.EarlyStopping import *


class BaseModel(object):
    """BaseModel is the base class of many models, such as STMeta, ST-MGCN and ST_ResNet;
    you can also build your own model using this class. More information can be found in the tutorial.

    Args:
        code_version: Current version of this model code, which will be used as the filename
            for saving the model.
        model_dir: The directory to store model files. Default: 'model_dir'.
        gpu_device: To specify the GPU to use. Default: '0'.
    """

    def __init__(self, code_version, model_dir, gpu_device):

        # model input and output
        self._input = {}
        self._output = {}
        self._op = {}
        self._variable_init = None
        self._saver = None

        self._code_version = code_version
        self._model_dir = model_dir

        # TF Graph
        self._graph = tf.Graph()
        self._converged = False
        self._log_dir = os.path.join(self._model_dir, self._code_version)
        self._global_step = 0
        self._summary = None
        self._summary_writer = tf.summary.FileWriter(self._log_dir)
        self.trainable_vars = 0

        # TF Session
        self._GPU_DEVICE = gpu_device
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = self._GPU_DEVICE
        self._config = tf.ConfigProto()
        self._config.gpu_options.allow_growth = True
        self._session = tf.Session(graph=self._graph, config=self._config)
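    # Usage note (illustrative sketch): self._input, self._output and self._op are
    # expected to map a short key to the *name string* of a tensor/operation inside
    # self._graph, because _run() resolves them with get_tensor_by_name /
    # get_operation_by_name. A subclass might register names like this
    # (the tensors X, loss and train_op here are hypothetical):
    #
    #   self._input['X'] = X.name            # e.g. 'X:0'
    #   self._output['loss'] = loss.name
    #   self._op['train_op'] = train_op.name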
    def build(self, init_vars=True, max_to_keep=5):
        """
        Args:
            init_vars(bool): auto-initialize the parameters if set to True, else no parameters
                will be initialized.
            max_to_keep: max number of checkpoint files to keep, equal to max_to_keep in tf.train.Saver.
        """
        with self._graph.as_default():
            ####################################################################
            # Add the saver, variable initializer and summary.
            # Their variable names are fixed.
            self.trainable_vars = np.sum([np.prod(v.get_shape().as_list())
                                          for v in tf.trainable_variables()])
            self._saver = tf.train.Saver(max_to_keep=max_to_keep)
            self._variable_init = tf.global_variables_initializer()
            self._summary = self._summary_histogram().name
            ####################################################################
            if init_vars:
                self._session.run(self._variable_init)
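    # A minimal subclass sketch (illustrative only; `ToyModel` and its tensors are
    # hypothetical placeholders, not part of UCTB). A subclass typically defines its
    # graph under self._graph, registers tensor/op names in the dictionaries above,
    # and then calls build():
    #
    #   class ToyModel(BaseModel):
    #       def __init__(self, code_version='Toy-V0', model_dir='model_dir', gpu_device='0'):
    #           super(ToyModel, self).__init__(code_version, model_dir, gpu_device)
    #           with self._graph.as_default():
    #               X = tf.placeholder(tf.float32, [None, 1], name='X')
    #               y = tf.placeholder(tf.float32, [None, 1], name='y')
    #               prediction = tf.layers.dense(X, 1)
    #               loss = tf.reduce_mean(tf.square(prediction - y))
    #               train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
    #               self._input['X'], self._input['y'] = X.name, y.name
    #               self._output['prediction'] = prediction.name
    #               self._output['loss'] = loss.name
    #               self._op['train_op'] = train_op.name
    #           self.build(init_vars=True)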
    def add_summary(self, name, value, global_step):
        value_record = tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=value)])
        self._summary_writer.add_summary(value_record, global_step)
    def _summary_histogram(self):
        with self._graph.as_default():
            for var in tf.trainable_variables():
                tf.summary.histogram(var.name, var)
            self._summary_writer.add_graph(self._graph)
            return tf.summary.merge_all()

    def _run(self, feed_dict, output_names, op_names):
        feed_dict_tf = {}
        for name, value in feed_dict.items():
            if value is not None:
                feed_dict_tf[self._graph.get_tensor_by_name(self._input[name])] = value

        output_tensor_list = [self._graph.get_tensor_by_name(self._output[name]) for name in output_names]
        output_tensor_list += [self._graph.get_operation_by_name(self._op[name]) for name in op_names]

        outputs = self._session.run(output_tensor_list, feed_dict=feed_dict_tf)

        return {output_names[i]: outputs[i] for i in range(len(output_names))}

    def _get_feed_dict(self, **kwargs):
        return kwargs
    def fit(self,
            sequence_length,
            output_names=('loss', ),
            op_names=('train_op', ),
            evaluate_loss_name='loss',
            batch_size=64,
            max_epoch=10000,
            validate_ratio=0.1,
            shuffle_data=True,
            early_stop_method='t-test',
            early_stop_length=10,
            early_stop_patience=0.1,
            verbose=True,
            save_model=True,
            save_model_name=None,
            auto_load_model=True,
            return_outputs=False,
            **kwargs):
        """
        Args:
            sequence_length: int, the sequence length which is used in mini-batch training
            output_names: list, [output_tensor1_name, output_tensor2_name, ...]
            op_names: list, [operation1_name, operation2_name, ...]
            evaluate_loss_name: str, should be one of the output_names; evaluate_loss_name is used
                in early stopping
            batch_size: int, default 64, batch size
            max_epoch: int, default 10000, max number of epochs
            validate_ratio: float, default 0.1, the ratio of data that will be used as the validation dataset
            shuffle_data: bool, default True, whether to shuffle data in mini-batch training
            early_stop_method: should be 't-test' or 'naive'; both methods are explained in train.EarlyStopping
            early_stop_length: int, must be provided when early_stop_method='t-test'
            early_stop_patience: int, must be provided when early_stop_method='naive'
            verbose: bool, flag to print training information or not
            save_model: bool, flag to save the model or not
            save_model_name: str, filename for saving the model, which will overwrite the code_version
            auto_load_model: bool, the fit function will automatically load the model from disk, if it
                exists, before training. Set to False to disable the auto-loading.
            return_outputs: bool, set True to return the training log, otherwise nothing will be returned
        """
        if auto_load_model:
            try:
                self.load(self._code_version)
                print('Found model in disk')
                if self._converged:
                    print('Model converged, stop training')
                    return
                else:
                    print('Model not converged, continue at step', self._global_step)
                    start_epoch = self._global_step
            except FileNotFoundError:
                print('No model found, start training')
                start_epoch = 0
        else:
            start_epoch = 0
            print('Not loading model from disk')

        if not 0 < validate_ratio < 1:
            raise ValueError('validate_ratio should be in (0, 1), given', validate_ratio)

        if evaluate_loss_name not in output_names:
            raise ValueError('evaluate_loss_name not shown in', output_names)

        if len(op_names) == 0:
            raise ValueError('No operation given')
        else:
            print('Running Operation', op_names)

        # Get feed_dict
        feed_dict = self._get_feed_dict(**kwargs)

        # Split data into train-data and validation data
        train_feed_dict, val_feed_dict = SplitData.split_feed_dict(
            feed_dict,
            sequence_length=sequence_length,
            ratio_list=[1 - validate_ratio, validate_ratio])

        train_sequence_length = int(sequence_length * (1 - validate_ratio))
        val_sequence_len = sequence_length - train_sequence_length

        # build mini-batch data source on train-data
        train_dict_mini_batch = MiniBatchFeedDict(feed_dict=train_feed_dict,
                                                  sequence_length=train_sequence_length,
                                                  batch_size=batch_size,
                                                  shuffle=shuffle_data)
        # record the best result of "evaluate_loss_name"
        best_record = None
        # init early stopping object
        if early_stop_method.lower() == 't-test':
            early_stop = EarlyStoppingTTest(length=early_stop_length, p_value_threshold=early_stop_patience)
        else:
            early_stop = EarlyStopping(patience=int(early_stop_patience))

        # start mini-batch training
        summary_output = []
        for epoch in range(start_epoch, max_epoch):
            train_output_list = []
            for i in range(train_dict_mini_batch.num_batch):
                # train
                train_output = self._run(feed_dict=train_dict_mini_batch.get_batch(),
                                         output_names=output_names,
                                         op_names=op_names)
                train_output_list.append(train_output)

            # validation
            val_output = self.predict(**val_feed_dict,
                                      output_names=output_names,
                                      sequence_length=val_sequence_len,
                                      cache_volume=batch_size)

            # Here we only care about the evaluate_loss_value
            evaluate_loss_value = np.mean(val_output[evaluate_loss_name])

            # Add Summary
            tmp_summary = {}
            for name in output_names:
                self.add_summary(name='train_' + name,
                                 value=np.mean([e[name] for e in train_output_list]),
                                 global_step=epoch)
                self.add_summary(name='val_' + name,
                                 value=np.mean(val_output[name]),
                                 global_step=epoch)
                # print training messages
                if verbose:
                    print('Epoch %s:' % epoch,
                          'train_' + name, np.mean([e[name] for e in train_output_list]),
                          'val_' + name, np.mean(val_output[name]))
                tmp_summary['train_' + name] = np.mean([e[name] for e in train_output_list])
                tmp_summary['val_' + name] = np.mean(val_output[name])
            summary_output.append(tmp_summary)

            # manual_summary the histograms
            self.manual_summary(global_step=epoch)

            if early_stop.stop(evaluate_loss_value):
                if save_model:
                    self._log('Converged')
                break

            # save the model if evaluate_loss_value is smaller than best_record
            if (best_record is None or evaluate_loss_value < best_record) and save_model:
                best_record = evaluate_loss_value
                self.save(save_model_name or self._code_version, epoch)

        if return_outputs:
            return summary_output
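    # Illustrative call to fit() (`model`, `train_x` and `train_y` are hypothetical):
    # keyword arguments beyond the named parameters are forwarded to _get_feed_dict
    # and must match the keys registered in self._input.
    #
    #   logs = model.fit(sequence_length=train_x.shape[0],
    #                    X=train_x, y=train_y,
    #                    batch_size=64, max_epoch=100,
    #                    early_stop_method='t-test',
    #                    early_stop_length=10, early_stop_patience=0.1,
    #                    return_outputs=True)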
    def predict(self, sequence_length, output_names=('prediction', ), cache_volume=64, **kwargs):
        '''
        Args:
            output_names: list, [output_tensor_name1, output_tensor_name2, ...]
            sequence_length: int, the length of the sequence, which is used in mini-batch prediction
            cache_volume: int, default 64, set cache_volume if the cache cannot hold the whole
                validation dataset

        :return: outputs_dict: dict, like {output_tensor1_name: output_tensor1_value, ...}
        '''
        # Get feed_dict
        feed_dict = self._get_feed_dict(**kwargs)
        if cache_volume and sequence_length:
            # storing the prediction result
            outputs_list = []
            outputs_dict = {}
            for i in range(0, sequence_length, cache_volume):
                tmp_output = self._run({key: value[i:i + cache_volume]
                                        if len(value) == sequence_length else value
                                        for key, value in feed_dict.items()},
                                       output_names, op_names=[])
                outputs_list.append(tmp_output)
            # stack the outputs together
            for key in outputs_list[0]:
                outputs_dict[key] = np.vstack([e[key] for e in outputs_list])
        else:
            outputs_dict = self._run(feed_dict, output_names, op_names=[])
        return outputs_dict
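    # Illustrative call to predict() (`model` and `test_x` are hypothetical): inputs
    # whose length equals sequence_length are sliced into cache_volume-sized chunks,
    # while everything else is fed unchanged on every chunk.
    #
    #   results = model.predict(sequence_length=test_x.shape[0], X=test_x,
    #                           output_names=('prediction',), cache_volume=64)
    #   y_hat = results['prediction']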
    def manual_summary(self, global_step=None):
        self._summary_writer.add_summary(
            self._session.run(self._graph.get_tensor_by_name(self._summary)),
            global_step=global_step)
    def _log(self, text):
        save_dir_subscript = os.path.join(self._log_dir, self._code_version)
        if os.path.isdir(save_dir_subscript) is False:
            os.makedirs(save_dir_subscript)
        with open(os.path.join(save_dir_subscript, 'log.txt'), 'a+', encoding='utf-8') as f:
            f.write(text + '\n')

    def _get_log(self):
        save_dir_subscript = os.path.join(self._log_dir, self._code_version)
        if os.path.isfile(os.path.join(save_dir_subscript, 'log.txt')):
            with open(os.path.join(save_dir_subscript, 'log.txt'), 'r', encoding='utf-8') as f:
                return [e.strip('\n') for e in f.readlines()]
        else:
            return []
    def save(self, subscript, global_step):
        """
        Args:
            subscript: String, subscript will be appended to the code version as the model filename,
                and the corresponding model will be saved using this filename
            global_step: Int, current training steps
        """
        save_dir_subscript = os.path.join(self._log_dir, subscript)
        # delete if exist
        # if os.path.isdir(save_dir_subscript):
        #     shutil.rmtree(save_dir_subscript, ignore_errors=True)
        if os.path.isdir(save_dir_subscript) is False:
            os.makedirs(save_dir_subscript)
        self._saver.save(sess=self._session,
                         save_path=os.path.join(save_dir_subscript, subscript),
                         global_step=global_step)
    def load(self, subscript):
        """
        Args:
            subscript: String, subscript will be appended to the code version as the model filename,
                and the corresponding model will be loaded using this filename
        """
        save_dir_subscript = os.path.join(self._log_dir, subscript)
        if not os.path.isdir(save_dir_subscript) or len(os.listdir(save_dir_subscript)) == 0:
            print('model Not Found')
            raise FileNotFoundError(subscript, 'model not found')
        else:
            meta_file = [e for e in os.listdir(save_dir_subscript)
                         if e.startswith(subscript) and e.endswith('.meta')]
            self._global_step = max([int(e.split('.')[0].split('-')[-1]) for e in meta_file])
            self._saver.restore(sess=self._session,
                                save_path=os.path.join(save_dir_subscript,
                                                       subscript + '-%s' % self._global_step))
            self._global_step += 1
            # parse the log-file
            log_list = self._get_log()
            for e in log_list:
                if e.lower() == 'converged':
                    self._converged = True
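    # Illustrative save/load round trip (`model` with code_version 'Toy-V0' is
    # hypothetical): fit() saves under save_model_name or the code version, so
    # loading with the same subscript restores the latest checkpoint.
    #
    #   model.save('Toy-V0', global_step=10)
    #   model.load('Toy-V0')     # restores step 10 and sets self._global_step to 11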
    def close(self):
        """
        Close the session, release memory.
        """
        self._session.close()
    def load_event_scalar(self, scalar_name='val_loss'):
        """
        Args:
            scalar_name: load the corresponding scalar name from the tensorboard event file,
                e.g. load_event_scalar('val_loss')
        """
        event_files = [e for e in os.listdir(self._log_dir) if e.startswith('events.out')]
        result = []
        for f in event_files:
            ea = event_accumulator.EventAccumulator(os.path.join(self._log_dir, f))
            ea.Reload()
            if scalar_name in ea.scalars.Keys():
                result += [[e.wall_time, e.step, e.value] for e in ea.scalars.Items(scalar_name)]
        return result
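    # Illustrative use of load_event_scalar (`model` is hypothetical): each entry is
    # [wall_time, step, value] for the requested scalar, read from the event files
    # written by add_summary during fit().
    #
    #   val_loss_curve = model.load_event_scalar('val_loss')
    #   best_step = min(val_loss_curve, key=lambda e: e[2])[1]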