Source code for UCTB.model.HM

import numpy as np

import warnings
# NOTE(review): this suppresses *all* warnings for the whole process as a
# side effect of importing this module, not just warnings raised here.
# Consider scoping it with ``warnings.catch_warnings()`` around the code that
# actually needs it.
warnings.filterwarnings("ignore")


class HM(object):
    """Baseline predictor that averages closeness, period and trend features.

    ``c``, ``p`` and ``t`` act as on/off switches: a feature group contributes
    to the prediction only when its switch is positive. Their magnitudes are
    not used; the number of steps per group comes from the feature arrays.
    """

    def __init__(self, c, p, t):
        """
        Args:
            c: closeness switch; closeness features are used when ``c > 0``.
            p: period switch; period features are used when ``p > 0``.
            t: trend switch; trend features are used when ``t > 0``.

        Raises:
            ValueError: if none of ``c``, ``p``, ``t`` is positive, because
                ``predict`` would then have no features to average.
        """
        self.c = c
        self.p = p
        self.t = t
        if self.c == 0 and self.p == 0 and self.t == 0:
            raise ValueError('c p t cannot all be zero at the same time')
        # ``predict`` treats any non-positive value as "disabled", so e.g.
        # (-1, 0, 0) used to pass the check above and then fail later inside
        # np.concatenate with an unhelpful "need at least one array" error.
        if not (self.c > 0 or self.p > 0 or self.t > 0):
            raise ValueError('at least one of c, p, t must be positive')

    def predict(self, closeness_feature, period_feature, trend_feature):
        """Average the enabled feature groups into a single prediction.

        Each enabled feature is indexed as ``feature[:, :, :, 0]``. Disabled
        groups are never touched, so they may be passed as ``None``.
        For 4-D inputs, presumably shaped (batch, node, steps, 1) (TODO
        confirm against the data loader), the result has shape
        (batch, node, 1).

        The mean is taken over all pooled steps of the enabled groups. A group
        with more steps therefore carries proportionally more weight than a
        group with fewer steps.

        Args:
            closeness_feature: closeness inputs, read only when ``self.c > 0``.
            period_feature: period inputs, read only when ``self.p > 0``.
            trend_feature: trend inputs, read only when ``self.t > 0``.

        Returns:
            ndarray: the mean over the concatenated step axis, with that axis
            kept as size 1.
        """
        groups = (
            (self.c, closeness_feature),
            (self.p, period_feature),
            (self.t, trend_feature),
        )
        # Keep only channel 0 so the per-group step axes line up as the last
        # axis and can be concatenated end to end.
        selected = [
            feature[:, :, :, 0] for enabled, feature in groups if enabled > 0
        ]
        return np.mean(np.concatenate(selected, axis=-1), axis=-1, keepdims=True)