[英]Keras custom generator when batch_size doesn't match with amount of data
I'm using Keras with Python 2.7.我在 Python 2.7 中使用 Keras。 I'm making my own data generator to compute batches for the train.
我正在制作自己的数据生成器来计算火车的批次。 I have some question about data_generator based on this model seen here :
我有一些关于基于这里看到的模型的 data_generator 的问题:
class DataGenerator(keras.utils.Sequence):
def __init__(self, list_IDs, ...):
#init
def __len__(self):
return int(np.floor(len(self.list_IDs) / self.batch_size))
def __getitem__(self, index):
indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
# Find list of IDs
list_IDs_temp = [self.list_IDs[k] for k in indexes]
# Generate data
X, y = self.__data_generation(list_IDs_temp)
return X, y
def on_epoch_end(self):
'Updates indexes after each epoch'
self.indexes = np.arange(len(self.list_IDs))
if self.shuffle == True:
np.random.shuffle(self.indexes)
def __data_generation(self, list_IDs_temp):
#generate data
return X, y
Okay, so here are my several questions :好的,这是我的几个问题:
Can you confirm my thinking about the order of function called ?你能确认我对调用函数顺序的想法吗? Here is :
这是 :
- __init__
- loop for each epoc :
- loop for each batches :
- __len_
- __get_item__ (+data generation)
- on_epoch_end
If you know a way to debug the generator I would like to know it, breakpoint and prints aren't working with this..如果您知道调试生成器的方法,我想知道它,断点和打印不适用于此..
More, I have a bad situation, but I think that everybody have the problem :更多,我的情况很糟糕,但我认为每个人都有问题:
For example, I have 200 datas (and 200 labels ok) and I want a batch size of 64 for example.例如,我有 200 个数据(和 200 个标签可以),例如我想要 64 的批量大小。 If I'm thinking well, __len_ will give 200/64 = 3 (instead of 3,125).
如果我想得好,__len_ 将给出 200/64 = 3(而不是 3,125)。 So 1 epoch will be done with 3 batches ?
那么 1 个 epoch 将用 3 个批次完成? What about the rest of the data ?
剩下的数据呢? I have an error because my amount of data is not a multiple of the batch size...
我有一个错误,因为我的数据量不是批量大小的倍数...
Second example, I have 200 data and I want a batch of 256 ?第二个例子,我有 200 个数据,我想要一批 256 ? What I have to do in this case to adapt my generator ?
在这种情况下我必须做些什么来适应我的发电机? I thought about checking if the batch_size is superior to my amount of data to feed the CNN with 1 batch, but the batch will not have the expected size so I thinks it will make an error ?
我想检查 batch_size 是否优于我的数据量,以便用 1 个批次为 CNN 提供数据,但是批次没有预期的大小,所以我认为它会出错?
Thanks you for the reading.感谢您的阅读。 I prefer to put pseudo-code because my questions are more about theory than coding errors !
我更喜欢放置伪代码,因为我的问题更多是关于理论而不是编码错误!
Normally you never mention the batch size in the model architecture, because it is a training parameter not a model parameter.通常你不会在模型架构中提到批量大小,因为它是一个训练参数而不是模型参数。 So it is OK to have different batch sizes while training.
所以在训练时有不同的批量大小是可以的。
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten
from keras.utils import to_categorical
import keras
#create model
model = Sequential()
#add model layers
model.add(Conv2D(64, kernel_size=3, activation='relu', input_shape=(10,10,1)))
model.add(Flatten())
model.add(Dense(2, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
class DataGenerator(keras.utils.Sequence):
def __init__(self, X, y, batch_size):
self.X = X
self.y = y
self.batch_size = batch_size
def __len__(self):
l = int(len(self.X) / self.batch_size)
if l*self.batch_size < len(self.X):
l += 1
return l
def __getitem__(self, index):
X = self.X[index*self.batch_size:(index+1)*self.batch_size]
y = self.y[index*self.batch_size:(index+1)*self.batch_size]
return X, y
X = np.random.rand(200,10,10,1)
y = to_categorical(np.random.randint(0,2,200))
model.fit_generator(DataGenerator(X,y,13), epochs=10)
Output:输出:
Epoch 1/10 16/16 [==============================] - 0s 2ms/step - loss: 0.6774 - acc: 0.6097
As you can see it has run 16 batches in one epoch ie 13*15+5=200
如您所见,它在一个 epoch 中运行了 16 个批次,即
13*15+5=200
Your generator is used in your python environment by Keras, if you cannot debug it, the reason is elsewhere. Keras在你的python环境中使用了你的生成器,如果你不能调试它,原因在别处。
cf : https://keras.io/utils/#sequence参考: https : //keras.io/utils/#sequence
__len__
: gives you the number of minibatches __len__
:给你小批量的数量
__getitem__
: gives the ith minibatch __getitem__
: 给出第 i 个小批量
You don't have to know when or where they are called but more like this :您不必知道他们何时何地被调用,但更像是这样:
- __init__
- __len_
- loop for each epoc :
- loop for each batches :
- __get_item__
- on_epoch_end
As for the minibatch size, you have two (classic) choices, either truncate or fill by picking again entries from your set.至于小批量大小,您有两个(经典)选择,要么截断,要么通过从您的集合中再次挑选条目来填充。 If you randomize your trainset every epoch as you should, there will be no overexposure or underexposure of some items over time
如果您按照应有的方式在每个时期随机化您的训练集,那么随着时间的推移,某些项目不会出现过度曝光或曝光不足的情况
The debugging aspect of this question sounds like the same question I tried posting recently and didn't get an answer.这个问题的调试方面听起来与我最近尝试发布但没有得到答案的问题相同。 I eventually figured it out and I think it's a simple principle that is easily missed by beginners.
我最终想通了,我认为这是一个简单的原则,初学者很容易错过。 You cannot break/debug at the level of the keras source code, if it's tensorflow-gpu underneath.
如果下面是 tensorflow-gpu,则无法在 keras 源代码级别进行中断/调试。 That keras code got 'translated' to run on the gpu.
该 keras 代码被“翻译”以在 gpu 上运行。 I thought perhaps it would be possible to break if running tensorflow on the cpu, but no that's not possible either.
我想如果在 cpu 上运行 tensorflow 可能会中断,但这也不可能。 There are ways to debug/break on the gpu at the tensorflow level, but that's gone beyond the simplicity of the high-level keras.
有一些方法可以在 tensorflow 级别在 gpu 上调试/中断,但这超出了高级 keras 的简单性。
You might be able to debug the generator if you pass "run_eagerly=True" in model.compile function.如果您在 model.compile 函数中传递“run_eagerly=True”,您可能能够调试生成器。 It says here :
它在这里说:
Running eagerly means that your model will be run step by step, like Python code.急切地运行意味着您的模型将逐步运行,就像 Python 代码一样。 Your model might run slower, but it should become easier for you to debug it by stepping into individual layer calls.
您的模型可能运行得较慢,但您应该更容易通过进入各个层调用来调试它。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.