简体   繁体   English

如何正确使用keras fit_generator

[英]How to use keras fit_generator properly

As the title says, I'm trying to use the fit_generator method from keras. 如标题所示,我正在尝试使用来自keras的fit_generator方法。

I'm working with images of 50x50. 我正在使用50x50的图片。 After some pre-processing, this is what I have: 经过一些预处理,这就是我所拥有的:

print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

(122, 50, 50, 1)
(122, 15)
(41, 50, 50, 1)
(41, 15)

This is the generator (which comes from here ): 这是生成器(来自此处 ):

def generator(features, labels, batch_size):
    # Create empty arrays to contain batch of features and labels#
    batch_features = np.zeros((batch_size, size, size, 1))
    batch_labels = np.zeros((batch_size, n_targets))
    while True:
        for i in range(batch_size):
            # choose random index in features
            index = random.choice(len(features),1)
            batch_features[i] = features[index]
            batch_labels[i] = labels[index]
        yield batch_features, batch_labels

And I called using: 我打电话给:

batch_size = 32
start_time = time.time()

model = create_model()
hist = model.fit_generator(generator(X_train, y_train, batch_size=batch_size),
                           steps_per_epoch=X_train.shape[0] // batch_size,
                           epochs=50, verbose=0, validation_data=(X_test, y_test))
# hist = model.fit(X_train, y_train, batch_size=16, epochs=100, verbose=0, validation_data=(X_test, y_test))

print("--- %s seconds ---" % (time.time() - start_time))

However this gives me an error: 但这给我一个错误:

ValueError: output of generator should be a tuple `(x, y, sample_weight)` or `(x, y)`. Found: None

I ran your code and I believe that the problem is that you should change the line index = random.choice(len(features),1) into index = np.random.choice(len(features),1) (note the added np. ) 我运行了代码,我相信问题是您应该将行index = random.choice(len(features),1) index = np.random.choice(len(features),1)index = np.random.choice(len(features),1) (请注意np.

Because when I run the code I get two errors after each other, the first indicates that random cannot be found. 因为当我运行代码时,我会彼此遇到两个错误,第一个错误表明找不到随机数。 The seconds indicates that no tuples are yielded. 秒表示没有元组产生。 So perhaps you might have missed the first error? 所以也许您可能错过了第一个错误?

When I change the line in question everything seems to work fine. 当我换线时,一切似乎都正常。

By the way, size and n_targets also need to be defined in the scope of course. 顺便说一下, n_targets还需要在size范围内定义sizen_targets

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM