[英]How to use keras fit_generator properly
As the title says, I'm trying to use the fit_generator
method from keras. 如标题所示,我正在尝试使用来自keras的fit_generator
方法。
I'm working with images of 50x50. 我正在使用50x50的图片。 After some pre-processing, this is what I have: 经过一些预处理,这就是我所拥有的:
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)
(122, 50, 50, 1)
(122, 15)
(41, 50, 50, 1)
(41, 15)
This is the generator (which comes from here ): 这是生成器(来自此处 ):
def generator(features, labels, batch_size):
# Create empty arrays to contain batch of features and labels#
batch_features = np.zeros((batch_size, size, size, 1))
batch_labels = np.zeros((batch_size, n_targets))
while True:
for i in range(batch_size):
# choose random index in features
index = random.choice(len(features),1)
batch_features[i] = features[index]
batch_labels[i] = labels[index]
yield batch_features, batch_labels
And I called using: 我打电话给:
batch_size = 32
start_time = time.time()
model = create_model()
hist = model.fit_generator(generator(X_train, y_train, batch_size=batch_size),
steps_per_epoch=X_train.shape[0] // batch_size,
epochs=50, verbose=0, validation_data=(X_test, y_test))
# hist = model.fit(X_train, y_train, batch_size=16, epochs=100, verbose=0, validation_data=(X_test, y_test))
print("--- %s seconds ---" % (time.time() - start_time))
However this gives me an error: 但这给我一个错误:
ValueError: output of generator should be a tuple `(x, y, sample_weight)` or `(x, y)`. Found: None
I ran your code and I believe that the problem is that you should change the line index = random.choice(len(features),1)
into index = np.random.choice(len(features),1)
(note the added np.
) 我运行了代码,我相信问题是您应该将行index = random.choice(len(features),1)
index = np.random.choice(len(features),1)
为index = np.random.choice(len(features),1)
(请注意np.
)
Because when I run the code I get two errors after each other, the first indicates that random cannot be found. 因为当我运行代码时,我会彼此遇到两个错误,第一个错误表明找不到随机数。 The seconds indicates that no tuples are yielded. 秒表示没有元组产生。 So perhaps you might have missed the first error? 所以也许您可能错过了第一个错误?
When I change the line in question everything seems to work fine. 当我换线时,一切似乎都正常。
By the way, size
and n_targets
also need to be defined in the scope of course. 顺便说一下, n_targets
还需要在size
范围内定义size
和n_targets
。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.