Source code for UCTB.train.MiniBatchTrain
import numpy as np
class MiniBatchTrain():
    """Serve fixed-size mini-batches from a shuffled ``(X, Y)`` data set.

    The samples are shuffled once at construction time. Each call to
    :meth:`get_batch` returns the next ``batch_size`` samples; when fewer than
    ``batch_size`` samples remain, the last ``batch_size`` samples are
    returned instead (so the final batch may overlap the previous one) and
    the iterator wraps around to the beginning.

    Args:
        X: Sequence of inputs; iterating over it must yield one sample at a
            time (the first dimension is the sample dimension).
        Y: Sequence of targets, paired with ``X`` sample by sample.
        batch_size (int): Number of samples per batch, must be positive.

    Attributes:
        num_batch (int): Number of ``get_batch`` calls needed to go through
            all samples once, i.e. ``ceil(sample_size / batch_size)``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, X, Y, batch_size):
        if batch_size <= 0:
            raise ValueError(
                'batch_size must be positive, got {!r}'.format(batch_size))
        self.__X, self.__Y = self.shuffle(X, Y)
        # shuffle() pairs the samples with zip(), which stops at the shorter
        # input, so the size is taken from the shuffled data rather than X.
        self.__sample_size = len(self.__X)
        self.__batch_counter = 0
        self.__batch_size = batch_size
        # Integer ceiling division; avoids the float round trip of int(a / b).
        self.num_batch = -(-self.__sample_size // self.__batch_size)

    @staticmethod
    def shuffle(X, Y):
        """Shuffle ``X`` and ``Y`` together, keeping each pair aligned.

        Args:
            X: Sequence of inputs.
            Y: Sequence of targets.

        Returns:
            tuple: ``(X_shuffled, Y_shuffled)`` as ``np.float32`` arrays.
        """
        pairs = list(zip(X, Y))
        np.random.shuffle(pairs)
        shuffled_x = np.array([pair[0] for pair in pairs], dtype=np.float32)
        shuffled_y = np.array([pair[1] for pair in pairs], dtype=np.float32)
        return shuffled_x, shuffled_y

    def get_batch(self):
        """Return the next ``(batch_x, batch_y)`` pair.

        Returns:
            tuple: Two array slices holding at most ``batch_size`` samples.
        """
        start = self.__batch_counter
        stop = start + self.__batch_size
        if stop <= self.__sample_size:
            batch_x = self.__X[start:stop]
            batch_y = self.__Y[start:stop]
            # Wrap around as soon as the data is exhausted. Otherwise the
            # counter would rest at sample_size and the next call would fall
            # into the tail branch below and repeat the last batch.
            self.__batch_counter = 0 if stop == self.__sample_size else stop
        else:
            # Not enough samples left for a full batch: serve the trailing
            # batch_size samples and start over.
            batch_x = self.__X[-self.__batch_size:]
            batch_y = self.__Y[-self.__batch_size:]
            self.__batch_counter = 0
        return batch_x, batch_y

    def restart(self):
        """Reset the iterator so the next batch starts at the first sample."""
        self.__batch_counter = 0
class MiniBatchTrainMultiData(object):
    """Serve fixed-size mini-batches from several aligned data sequences.

    Each call to :meth:`get_batch` returns the next ``batch_size`` samples of
    every sequence. When fewer than ``batch_size`` samples remain, the last
    ``batch_size`` samples are returned instead (so the final batch may
    overlap the previous one) and the iterator wraps around to the beginning.

    Args:
        data: Sequence of data sequences that share the same sample
            dimension, e.g. ``[X, Y, Z]``. The sample size is read from the
            first sequence.
        batch_size (int): Number of samples per batch, must be positive.
        shuffle (bool): Whether to shuffle the samples (jointly across all
            sequences) once at construction time. Default: True.

    Attributes:
        num_batch (int): Number of ``get_batch`` calls needed to go through
            all samples once, i.e. ``ceil(sample_size / batch_size)``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, data, batch_size, shuffle=True):
        if batch_size <= 0:
            raise ValueError(
                'batch_size must be positive, got {!r}'.format(batch_size))
        self.__data = self.shuffle(data) if shuffle else data
        self.__sample_size = len(self.__data[0])
        self.__batch_counter = 0
        self.__batch_size = batch_size
        # Integer ceiling division; avoids the float round trip of int(a / b).
        self.num_batch = -(-self.__sample_size // self.__batch_size)

    @staticmethod
    def shuffle(data):
        """Shuffle all sequences in ``data`` with one common permutation.

        Args:
            data: Sequence of data sequences with aligned samples.

        Returns:
            list: One tuple per input sequence, holding its shuffled samples.
        """
        samples = list(zip(*data))
        np.random.shuffle(samples)
        return list(zip(*samples))

    def get_batch(self):
        """Return the next batch of every data sequence.

        Returns:
            list: One ``np.ndarray`` per data sequence, each holding at most
            ``batch_size`` samples.
        """
        start = self.__batch_counter
        stop = start + self.__batch_size
        if stop <= self.__sample_size:
            # Wrap around as soon as the data is exhausted. Otherwise the
            # counter would rest at sample_size and the next call would
            # repeat the last batch.
            self.__batch_counter = 0 if stop == self.__sample_size else stop
        else:
            # Not enough samples left for a full batch: serve the trailing
            # batch_size samples. The start is clamped at 0 because a
            # negative start (batch_size > sample_size) would be read as an
            # offset from the end and silently drop leading samples.
            start = max(self.__sample_size - self.__batch_size, 0)
            stop = self.__sample_size
            self.__batch_counter = 0
        return [np.array(sequence[start:stop]) for sequence in self.__data]

    def restart(self):
        """Reset the iterator so the next batch starts at the first sample."""
        self.__batch_counter = 0
class MiniBatchFeedDict(object):
    """Serve mini-batches of a dictionary whose values are partly per-sample.

    Entries of ``feed_dict`` whose length equals ``sequence_length`` are
    treated as per-sample data and are sliced into batches. All other entries
    (including values that have no length, such as plain scalars) are passed
    through unchanged in every batch.

    When fewer than ``batch_size`` samples remain, the last ``batch_size``
    samples are returned instead (so the final batch may overlap the previous
    one) and the iterator wraps around to the beginning.

    Args:
        feed_dict (dict): Mapping from names to values.
        sequence_length (int): Number of samples; used to detect which
            entries are per-sample data.
        batch_size (int): Number of samples per batch, must be positive.
        shuffle (bool): Whether to shuffle the per-sample entries (jointly)
            once at construction time. Default: True.

    Attributes:
        num_batch (int): Number of ``get_batch`` calls needed to go through
            all samples once, i.e. ``ceil(sequence_length / batch_size)``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, feed_dict, sequence_length, batch_size, shuffle=True):
        if batch_size <= 0:
            raise ValueError(
                'batch_size must be positive, got {!r}'.format(batch_size))
        self._sequence_length = sequence_length
        self._batch_size = batch_size
        self._dynamic_data_names = []
        self._dynamic_data_values = []
        self._batch_dict = {}
        for key, value in feed_dict.items():
            if self._has_length(value, sequence_length):
                self._dynamic_data_names.append(key)
                self._dynamic_data_values.append(value)
            else:
                self._batch_dict[key] = value
        if shuffle:
            self._dynamic_data_values = MiniBatchFeedDict.shuffle(
                self._dynamic_data_values)
        self._batch_counter = 0
        # Integer ceiling division; avoids the float round trip of int(a / b).
        self.num_batch = -(-self._sequence_length // self._batch_size)

    @staticmethod
    def _has_length(value, length):
        """Return True if ``value`` supports ``len()`` and has that length."""
        try:
            return len(value) == length
        except TypeError:
            # Scalars and 0-d arrays have no length: they are static entries.
            return False

    def get_batch(self):
        """Return the dictionary for the next batch.

        Note:
            The same dictionary object is updated in place and returned on
            every call; copy it if an earlier batch must be kept.

        Returns:
            dict: Static entries unchanged, per-sample entries replaced by an
            ``np.ndarray`` holding at most ``batch_size`` samples.
        """
        start = self._batch_counter
        stop = start + self._batch_size
        if stop <= self._sequence_length:
            # Wrap around as soon as the data is exhausted. Otherwise the
            # counter would rest at sequence_length and the next call would
            # repeat the last batch.
            self._batch_counter = 0 if stop == self._sequence_length else stop
        else:
            # Not enough samples left for a full batch: serve the trailing
            # batch_size samples. The start is clamped at 0 because a
            # negative start (batch_size > sequence_length) would be read as
            # an offset from the end and silently drop leading samples.
            start = max(self._sequence_length - self._batch_size, 0)
            stop = self._sequence_length
            self._batch_counter = 0
        for key, values in zip(self._dynamic_data_names,
                               self._dynamic_data_values):
            self._batch_dict[key] = np.array(values[start:stop])
        return self._batch_dict

    @staticmethod
    def shuffle(data):
        """Shuffle all sequences in ``data`` with one common permutation.

        Args:
            data: Sequence of data sequences with aligned samples.

        Returns:
            list: One tuple per input sequence, holding its shuffled samples.
        """
        samples = list(zip(*data))
        np.random.shuffle(samples)
        return list(zip(*samples))

    def restart(self):
        """Reset the iterator so the next batch starts at the first sample."""
        self._batch_counter = 0