# Source code for UCTB.train.EarlyStopping
import math

from scipy import stats
class EarlyStopping(object):
    """Early stop when recent records stop improving on the best record.

    A record is "better" when it is strictly lower than the current best
    (``new_value < best``), so this class is meant for metrics where lower
    is better, such as a validation loss.

    Args:
        patience (int): How many non-improving records are tolerated after
            the current best before early stop is triggered.

    Attributes:
        __record_list (list): Every record passed to :meth:`stop`, in order.
        __best (float): The current best (lowest) record, ``None`` before
            the first record arrives.
        __patience (int): The number of tolerated non-improving records.
        __p (int): The number of non-improving records counted since the
            current best record was set.
    """

    def __init__(self, patience):
        self.__record_list = []
        self.__best = None
        self.__patience = patience
        self.__p = 0

    def stop(self, new_value):
        """Append the new record and decide whether training should stop.

        Args:
            new_value (float): The new record generated by the newest model.

        Returns:
            bool: ``True`` if the number of records that are not better than
            the best record exceeds ``patience`` (i.e. on the
            ``patience + 1``-th non-improving record since the best one),
            otherwise ``False``.
        """
        self.__record_list.append(new_value)

        # A strictly lower value becomes the new best and resets the counter.
        if self.__best is None or new_value < self.__best:
            self.__best = new_value
            self.__p = 0
            return False

        # Not an improvement: consume one unit of patience while any is left.
        if self.__p < self.__patience:
            self.__p += 1
            return False

        # Patience is exhausted and the record still did not improve.
        return True
class EarlyStoppingTTest(object):
    """Early stop by t-test.

    T-test is a two-sided test for the null hypothesis that 2 independent
    samples have identical average (expected) values. This class takes the
    two most recent, non-overlapping intervals of ``length`` records and
    runs Welch's t-test (``equal_var=False``) on them. Training continues
    only while the newest interval is significantly lower than the previous
    one; if the averages look identical, or the newest interval is higher,
    early stop is triggered.

    Args:
        length (int): The length of checked interval. Should be at least 2
            for the t-test itself to be defined.
        p_value_threshold (float): The p-value threshold to decide whether
            to do early stop.

    Attributes:
        __record_list (list): List of records.
        __best (float): The current best record (initialised to ``None``;
            not updated by this class).
        __test_length (int): The length of checked interval.
        __p_value_threshold (float): The p-value threshold to decide whether
            to do early stop.
    """

    def __init__(self, length, p_value_threshold):
        self.__record_list = []
        self.__best = None
        self.__test_length = length
        self.__p_value_threshold = p_value_threshold

    def stop(self, new_value):
        """Append the new record and t-test the two newest intervals.

        Args:
            new_value (float): The new record generated by the newest model.

        Returns:
            bool: ``False`` while fewer than ``2 * length`` records exist.
            Afterwards ``True`` (early stop) if the p-value is larger than
            the threshold (the two intervals have statistically identical
            averages) or the t statistic is positive (the newest interval is
            higher than the previous one), otherwise ``False``. If the
            t-test is undefined (NaN p-value, e.g. both intervals are
            constant and equal), ``True`` unless the newest interval has a
            strictly lower total than the previous one.
        """
        self.__record_list.append(new_value)

        window = self.__test_length
        if len(self.__record_list) < window * 2:
            # Not enough history yet to form two full intervals.
            return False

        recent = self.__record_list[-window:]
        previous = self.__record_list[-window * 2:-window]

        lossTTest = stats.ttest_ind(recent, previous, equal_var=False)
        ttest = lossTTest[0]
        pValue = lossTTest[1]
        print('ttest:', ttest, 'pValue', pValue)

        if math.isnan(pValue):
            # The test is undefined, so every comparison against NaN would
            # be False and a perfectly flat record would never stop.
            # Both intervals have the same length, so comparing their sums
            # is comparing their averages: keep going only on a strict
            # improvement.
            return not (sum(recent) < sum(previous))

        if pValue > self.__p_value_threshold or ttest > 0:
            return True
        return False