
How to predict a new data series with a trained Keras GRU model?

I'm trying to use a trained Keras sequence model (GRU) to predict some new data samples, but I'm having trouble creating the time series generator.

During training, the validation set was predicted using model.predict_generator(), which took as input a Python generator created by keras.preprocessing.sequence.TimeseriesGenerator(). I wanted to repeat the process with a new test set, only to find that TimeseriesGenerator() requires both data and targets as input. But in this case, the targets (i.e. y_test) are exactly what I expect to get from the predict function.

A simplified version of my training code looks like this:

training_generator = TimeseriesGenerator(X_train, y_train, length=timesteps * sampling_rate, sampling_rate=sampling_rate, batch_size=batch_size, shuffle=shuffle_data)
test_generator = TimeseriesGenerator(X_test, y_test, length=timesteps * sampling_rate, sampling_rate=sampling_rate, batch_size=batch_size, shuffle=shuffle_data)

model.fit_generator(generator=training_generator, epochs=epochs, use_multiprocessing=False, verbose=2)
y_test_pred = model.predict_generator(generator=test_generator)

I also thought about writing a custom generator myself, but then it would be really hard to verify that it is equivalent to the official time series generator.

Is there any way to use TimeseriesGenerator() without providing targets?

Thank you for your help!

A simple workaround would be to generate dummy targets, since predict_generator will ignore them:

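import numpy as np

# X_test: your new test data (a NumPy array, one row per time step)
y_dummy = np.zeros((X_test.shape[0],))  # placeholder targets; predict_generator does not use them
# shuffle=False keeps the predictions in the same order as the rows of X_test
test_generator = TimeseriesGenerator(X_test, y_dummy, length=timesteps * sampling_rate, sampling_rate=sampling_rate, batch_size=batch_size, shuffle=False)
y_test_pred = model.predict_generator(generator=test_generator)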

You should adjust the shape of y_dummy if your labels are multidimensional.
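
If you would rather not pass targets at all, you can also build the input windows yourself with NumPy and call model.predict on them. The following is only a minimal sketch: make_windows and X_demo are hypothetical names, and it assumes the generator's default stride of 1 with shuffling turned off, which is how I understand TimeseriesGenerator to slice its data. It is worth comparing its output against a real generator on a small array before relying on it.

```python
import numpy as np

def make_windows(data, length, sampling_rate=1):
    """Build input windows the way TimeseriesGenerator is assumed to
    (stride=1, shuffle=False): window k covers rows k .. k+length-1,
    keeping every `sampling_rate`-th row."""
    data = np.asarray(data)
    if len(data) <= length:
        raise ValueError("need more than `length` rows of data")
    # Row `end` is the time step that each window would be predicting.
    return np.stack([data[end - length:end:sampling_rate]
                     for end in range(length, len(data))])

# Toy data: 50 time steps, 1 feature (hypothetical stand-in for X_test).
X_demo = np.arange(50).reshape(-1, 1)
windows = make_windows(X_demo, length=10, sampling_rate=2)
# windows has shape (40, 5, 1); windows[0] holds rows 0, 2, 4, 6, 8.
# With a trained model: y_pred = model.predict(windows, batch_size=batch_size)
```

Under these assumptions, prediction k corresponds to time step k + length of the input, so the first `length` rows of the test data get no prediction.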
