繁体   English   中英

Python/Keras - 如何访问每个纪元预测?

[英]Python/Keras - How to access each epoch prediction?

我正在使用 Keras 来预测时间序列。 作为标准,我使用 20 个时代。 我想通过预测 20 个 epoch 中的每一个来检查我的模型是否学习良好。

通过使用model.predict()我在所有时代中只得到一个预测(不确定model.predict()如何选择它)。 我想要所有的预测,或者至少 10 个最好的。

有谁知道如何帮助我?

我认为这里有点混乱。

epoch 仅在训练神经网络时使用,因此当训练停止时(在这种情况下,在第 20 个 epoch 之后),则权重对应于在最后一个 epoch 上计算的权重。

Keras 在每个 epoch 之后的训练期间在验证集上打印当前损失值。 如果每个 epoch 之后的权重没有被保存,那么它们就会丢失。 您可以使用ModelCheckpoint回调为每个 epoch 保存权重,然后使用load_weights将它们加载回您的模型。

您可以通过在on_epoch_end函数内对Callback进行子类化并在模型上调用 predict 来实现适当的回调,从而在每个训练时期之后计算您的预测。

然后要使用它,您需要实例化您的回调,创建一个列表并将其用作model.fit 的关键字参数回调。

以下代码将完成所需的工作:

import tensorflow as tf
import keras

# define your custom callback for prediction
class PredictionCallback(tf.keras.callbacks.Callback):    
  def on_epoch_end(self, epoch, logs={}):
    y_pred = self.model.predict(self.validation_data[0])
    print('prediction: {} at epoch: {}'.format(y_pred, epoch))

# ...

# register the callback before training starts
model.fit(X_train, y_train, batch_size=32, epochs=25, 
          validation_data=(X_valid, y_valid), 
          callbacks=[PredictionCallback()])

如果你想对测试数据进行预测,你可以试试这个

class CustomCallback(keras.callbacks.Callback):
    def __init__(self, model, x_test, y_test):
        self.model = model
        self.x_test = x_test
        self.y_test = y_test

    def on_epoch_end(self, epoch, logs={}):
        y_pred = self.model.predict(self.x_test, self.y_test)
        print('y predicted: ', y_pred)

您需要在 model.fit 期间提及回调

model.sequence()
# your model architecture
model.fit(x_train, y_train, epochs=10, 
          callbacks=[CustomCallback(model, x_test, y_test)])

on_epoch_end类似,keras 提供了许多其他方法

on_train_begin, on_train_end, on_epoch_begin, on_epoch_end, on_test_begin,
on_test_end, on_predict_begin, on_predict_end, on_train_batch_begin, on_train_batch_end,
on_test_batch_begin, on_test_batch_end, on_predict_batch_begin,on_predict_batch_end

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM