[英]How does tensorflow-keras calculate the cost during training in each epoch?
[英]How can I call a test set at the end of each epoch during the training? I am using tensorflow
我正在使用 Tensorflow-Keras 開發一個 CNN 模型,其中我將數據集拆分為訓練、驗證和測試集。 我需要在每個時期結束時調用測試集以及訓練和驗證集來評估模型性能。 下面是我跟蹤訓練和驗證集的代碼。
result_dic = {"epochs": []}
json_logging_callback = LambdaCallback(
on_epoch_begin=lambda epoch, logs: [learning_rate],
on_epoch_end=lambda epoch, logs:
result_dic["epochs"].append({
'epoch': epoch + 1,
'acc': str(logs['acc']),
'val_acc': str(logs['val_acc'])
}))
model.fit(x_train, y_train,
validation_data=(x_val, y_val),
batch_size=batch_size,
epochs=epochs,
callbacks=[json_logging_callback])
輸出:
Epoch 1/5
1/1 [==============================] - 4s 4s/step - acc: 0.8611 - val_acc: 0.8333
但是,我不確定如何將測試集添加到我的回調中以產生以下輸出。
預期輸出:
Epoch 1/5
1/1 [==============================] - 4s 4s/step - acc: 0.8611 - val_acc: 0.8333 - test_acc: xxx
要在每個 epoch 后顯示您的測試准確率,您可以自定義fit
函數以顯示此指標。 看看這個文檔,或者你可以,如圖所示這里,定義了一個簡單的回調為您的測試數據集,並把它傳遞到您的fit
函數:
model.fit(x_train, y_train,
validation_data=(x_val, y_val),
batch_size=batch_size,
epochs=epochs,
callbacks=[json_logging_callback,
your_test_callback((X_test, Y_test))])
如果你想要完全的靈活性,你可以嘗試使用訓練循環。
更新:由於您希望為所有指標使用單個 JSON,您應該執行以下操作:
定義您的TestCallBack
並將您的測試准確性(如果需要,還包括loss
)添加到您的logs
字典中:
import tensorflow as tf
class TestCallback(tf.keras.callbacks.Callback):
def __init__(self, test_data):
self.test_data = test_data
def on_epoch_end(self, epoch, logs):
x, y = self.test_data
loss, acc = self.model.evaluate(x, y, verbose=0)
logs['test_accuracy'] = acc
然后將測試准確度添加到您的結果字典中:
result_dic = {"epochs": []}
json_logging_callback = tf.keras.callbacks.LambdaCallback(
on_epoch_begin=lambda epoch, logs: [learning_rate],
on_epoch_end=lambda epoch, logs:
result_dic["epochs"].append({
'epoch': epoch + 1,
'acc': str(logs['accuracy']),
'val_acc': str(logs['val_accuracy']),
'test_acc': str(logs['test_accuracy'])
}))
然后在您的fit
函數中使用這兩個回調,但請注意回調的順序:
model.fit(x_train, y_train,
validation_data=(x_val, y_val),
batch_size=batch_size,
epochs=epochs,
callbacks=callbacks=[TestCallback((x_test, y_test)), json_logging_callback])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.