
How to use keras fit_generator properly

As the title says, I'm trying to use the fit_generator method from keras.

I'm working with 50x50 images. After some preprocessing, this is what I have:

print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

(122, 50, 50, 1)
(122, 15)
(41, 50, 50, 1)
(41, 15)

This is the generator (which comes from here):

def generator(features, labels, batch_size):
    # Create empty arrays to hold one batch of features and labels
    batch_features = np.zeros((batch_size, size, size, 1))
    batch_labels = np.zeros((batch_size, n_targets))
    while True:
        for i in range(batch_size):
            # choose random index in features
            index = random.choice(len(features),1)
            batch_features[i] = features[index]
            batch_labels[i] = labels[index]
        yield batch_features, batch_labels

And I call it like this:

batch_size = 32
start_time = time.time()

model = create_model()
hist = model.fit_generator(generator(X_train, y_train, batch_size=batch_size),
                           steps_per_epoch=X_train.shape[0] // batch_size,
                           epochs=50, verbose=0, validation_data=(X_test, y_test))
# hist = model.fit(X_train, y_train, batch_size=16, epochs=100, verbose=0, validation_data=(X_test, y_test))

print("--- %s seconds ---" % (time.time() - start_time))
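For reference, the steps_per_epoch expression in the call above is plain floor division. A quick check of what it works out to, using the 122 training samples printed earlier and batch_size = 32:

```python
n_train = 122    # X_train.shape[0], from the question's printed shapes
batch_size = 32

# Floor division, as in steps_per_epoch=X_train.shape[0] // batch_size
steps_per_epoch = n_train // batch_size
samples_per_epoch = steps_per_epoch * batch_size

print(steps_per_epoch, samples_per_epoch)  # 3 96
```

So each of the 50 epochs consumes 3 batches (96 images). Because the generator draws every index at random, a given epoch may repeat some images and skip others.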

However, this gives me an error:

ValueError: output of generator should be a tuple `(x, y, sample_weight)` or `(x, y)`. Found: None

I ran your code, and I believe the problem is the line index = random.choice(len(features),1). Change it to index = np.random.choice(len(features),1) (note the added np.).

I say this because, when I run the code, I get two errors in a row: the first indicates that random cannot be found, and the second indicates that no tuple is yielded. So perhaps you missed the first error?
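One way to make sure that first error is not missed is to pull a single batch from the generator yourself with next(), outside of Keras. The sketch below is a hypothetical reproduction: it assumes size = 50 and n_targets = 15 (taken from the question's shapes), uses zero-filled stand-in arrays, and adds an import of the standard-library random module. With that import, random.choice(len(features), 1) raises a TypeError, because the standard-library choice takes a single sequence argument; with no import at all, the same line raises a NameError instead, which matches the "random cannot be found" error described above. The helper names buggy_generator and first_batch_error are illustrative only.

```python
import random  # standard-library random, NOT numpy.random
import numpy as np

size, n_targets = 50, 15  # assumed from the question's shapes

def buggy_generator(features, labels, batch_size):
    """The question's generator, unchanged except for the assumed globals above."""
    batch_features = np.zeros((batch_size, size, size, 1))
    batch_labels = np.zeros((batch_size, n_targets))
    while True:
        for i in range(batch_size):
            # stdlib random.choice(seq) accepts one argument, so this call raises
            index = random.choice(len(features), 1)
            batch_features[i] = features[index]
            batch_labels[i] = labels[index]
        yield batch_features, batch_labels

def first_batch_error(gen):
    """Pull one batch from a generator; return the exception it raised, or None."""
    try:
        next(gen)
    except Exception as err:  # report whatever the generator raised
        return err
    return None

# Zero-filled stand-in data with the question's training shapes (hypothetical)
X_stub = np.zeros((122, size, size, 1))
y_stub = np.zeros((122, n_targets))

err = first_batch_error(buggy_generator(X_stub, y_stub, batch_size=32))
print(type(err).__name__)  # TypeError
```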

With that line changed, everything seems to work fine.
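Why the corrected line works even though it passes a size argument: np.random.choice(len(features), 1) returns a one-element array rather than an integer, so features[index] keeps an extra leading axis of length 1. NumPy strips leading length-1 axes when assigning into a smaller target, so the result still fits into batch_features[i]. The check below uses random stand-in data with the question's shape (an assumption, not the asker's images):

```python
import numpy as np

features = np.random.rand(122, 50, 50, 1)  # hypothetical stand-in for X_train

index = np.random.choice(len(features), 1)  # one-element integer array
print(index.shape)            # (1,)
print(features[index].shape)  # (1, 50, 50, 1)

batch_features = np.zeros((32, 50, 50, 1))
# NumPy strips the leading length-1 axis on assignment, so no reshape is needed
batch_features[0] = features[index]

# Without the size argument, np.random.choice returns a single integer instead
scalar_index = np.random.choice(len(features))
print(features[scalar_index].shape)  # (50, 50, 1)
```

Dropping the size argument, as in the last two lines, avoids the extra axis altogether.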

By the way, size and n_targets also need to be defined in scope, of course.
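Putting the answer together, here is a minimal sketch of the corrected generator. It uses np.random.choice and defines size and n_targets at module level, with values assumed from the question's shapes ((122, 50, 50, 1) inputs and (122, 15) labels). The random X_train/y_train arrays are hypothetical stand-ins so the snippet runs on its own; the generator itself should be usable with fit_generator exactly as in the question.

```python
import numpy as np

size = 50        # image height/width, assumed from the (50, 50, 1) inputs
n_targets = 15   # number of label columns, assumed from y_train.shape == (122, 15)

def generator(features, labels, batch_size):
    # Create empty arrays to hold one batch of features and labels
    batch_features = np.zeros((batch_size, size, size, 1))
    batch_labels = np.zeros((batch_size, n_targets))
    while True:
        for i in range(batch_size):
            # choose a random index into features (np.random, not the stdlib random)
            index = np.random.choice(len(features), 1)
            batch_features[i] = features[index]
            batch_labels[i] = labels[index]
        yield batch_features, batch_labels

# Hypothetical stand-in data with the question's training shapes
X_train = np.random.rand(122, size, size, 1)
y_train = np.random.rand(122, n_targets)

batch_x, batch_y = next(generator(X_train, y_train, batch_size=32))
print(batch_x.shape, batch_y.shape)  # (32, 50, 50, 1) (32, 15)
```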

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint, please credit the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM