Source code for UCTB.train.EarlyStopping

import math

from scipy import stats


class EarlyStopping(object):
    """Patience-based early stopping for a metric where lower is better.

    Every record passed to :meth:`stop` is compared with the best (lowest)
    record seen so far. A strictly lower record becomes the new best and
    resets the counter of non-improving records; any other record increases
    that counter until it reaches ``patience``, after which :meth:`stop`
    reports that training should end.

    Args:
        patience (int): How many non-improving records are tolerated before
            the next non-improving record triggers early stop.

    Attributes:
        __record_list (list): Every record received so far, in order.
        __best (float): The lowest record seen so far (``None`` before the
            first record).
        __patience (int): The tolerated number of non-improving records.
        __p (int): Non-improving records counted since the current best was
            set; never grows beyond ``__patience``.
    """

    def __init__(self, patience):
        self.__patience = patience
        self.__record_list = []
        self.__best = None
        self.__p = 0

    def stop(self, new_value):
        """Record ``new_value`` and decide whether training should stop.

        Args:
            new_value (float): The newest record (lower is better).

        Returns:
            bool: ``False`` when ``new_value`` is the first record or is
            strictly lower than the current best, and also while fewer than
            ``patience`` non-improving records have been counted. ``True``
            once a non-improving record arrives after ``patience``
            non-improving records have already been counted.
        """
        self.__record_list.append(new_value)

        # A record equal to the best is NOT an improvement (strict "<").
        improved = self.__best is None or new_value < self.__best
        if improved:
            self.__best, self.__p = new_value, 0
            return False

        # Still within the tolerated budget: count this record and go on.
        if self.__p < self.__patience:
            self.__p += 1
            return False

        return True
class EarlyStoppingTTest(object):
    """Early stop by t-test.

    The t-test used here is a two-sided test for the null hypothesis that
    two independent samples have identical average (expected) values
    (Welch's variant, i.e. without assuming equal variances). On every new
    record the newest ``length`` records are compared with the ``length``
    records right before them. Training is stopped when the two windows
    cannot be told apart (p-value above the threshold) or when the newest
    window has the larger mean (positive t statistic).

    Args:
        length (int): The length of each checked interval.
        p_value_threshold (float): The p-value threshold to decide whether
            to do early stop.

    Attributes:
        __record_list (list): List of records.
        __best (float): The current best record (initialised to ``None`` and
            not updated by this class).
        __test_length (int): The length of each checked interval.
        __p_value_threshold (float): The p-value threshold to decide whether
            to do early stop.
    """

    def __init__(self, length, p_value_threshold):
        self.__record_list = []
        self.__best = None
        self.__test_length = length
        self.__p_value_threshold = p_value_threshold

    def stop(self, new_value):
        """Append the new record and t-test the two newest intervals.

        Args:
            new_value (float): The new record generated by the newest model.

        Returns:
            bool: ``False`` while fewer than ``2 * length`` records have been
            collected. Afterwards ``True`` if the p-value of the t-test is
            greater than the threshold or the t statistic is positive (the
            newest interval has the larger mean), otherwise ``False``.
            When the t-test is undefined (NaN p-value, e.g. both intervals
            are constant and equal), the interval means are compared
            directly and ``True`` is returned if the newest mean is not
            lower than the previous one.
        """
        self.__record_list.append(new_value)

        window = self.__test_length
        if len(self.__record_list) < window * 2:
            return False

        recent = self.__record_list[-window:]
        previous = self.__record_list[-window * 2:-window]

        loss_t_test = stats.ttest_ind(recent, previous, equal_var=False)
        t_stat = loss_t_test[0]
        p_value = loss_t_test[1]
        print('ttest:', t_stat, 'pValue', p_value)

        if math.isnan(p_value):
            # The test statistic is 0/0 when both windows have zero variance
            # and equal means, i.e. the metric is perfectly flat. Every
            # comparison with NaN is False, so without this branch a flat
            # metric would never trigger early stop.
            if not recent or not previous:
                # Degenerate (non-positive) length: nothing to compare.
                return False
            recent_mean = sum(recent) / len(recent)
            previous_mean = sum(previous) / len(previous)
            # NaN records make the means NaN, which keeps this False.
            return bool(recent_mean >= previous_mean)

        return bool(p_value > self.__p_value_threshold or t_stat > 0)