簡體   English   中英

在 tensorflow2.0 中,如果我使用 tf.keras.models.Model。 我可以通過模型訓練批次的數量來評估和保存模型嗎?

[英]In tensorflow2.0, if I use tf.keras.models.Model. Can I evaluate and save the model by the number of model training batches?

現在我使用tf.keras下的接口來封裝我的CNN(Convolutional Neural Network)。

model.fit界面,我發現參數validation_freq只提供了epoch選項。 這意味着在訓練模型時,您只能在訓練一個或多個 epoch 后評估和保存模型。 在很多情況下,更希望根據模型訓練批次的數量來評估和保存模型。 model.fitmodel.fit_generator現在提供這個選項嗎?

您可以通過實現這樣的自定義回調來實現:

import tensorflow as tf
print(tf.__version__) # 2.1.0

class PerBatchLogsCallback(tf.keras.callbacks.Callback):
    def __init__(self, x_test, y_test, interval=2):
        self.x_test = x_test
        self.y_test = y_test
        self.n_batches = 0
        self.interval = interval
        self.all_logs = None

    def on_batch_end(self, batch, logs=None):
        if self.all_logs is None:
            self.all_logs = {name: [] for name in self.model.metrics_names}
            self.all_logs['batch'] = []

        if self.n_batches % self.interval == 0:
            evaluated = self.model.evaluate(self.x_test, self.y_test)
            for e, name in zip(evaluated, self.model.metrics_names):
                self.all_logs[name].append(e)
            self.all_logs['batch'].append(batch)
        self.n_batches += 1

每兩批評估驗證指標的使用示例:

import tensorflow as tf 
import numpy as np

inputs = tf.keras.layers.Input((2,))
res = tf.keras.layers.Dense(2, activation=tf.nn.softmax)(inputs)
model = tf.keras.Model(inputs, res)
model.compile(optimizer=tf.keras.optimizers.SGD(0.001),
              metrics=['accuracy'],
              loss='sparse_categorical_crossentropy')

x_train = np.random.normal(0, 1, (10, 2))
y_train = np.random.randint(0, 2, 10)

x_test = np.random.normal(0, 1, (5, 2))
y_test = np.random.randint(0, 2, 5)

batch_callback = PerBatchLogsCallback(x_test, y_test)

model.fit(x_train,
          y_train,
          batch_size=2,
          epochs=1,
          callbacks=[batch_callback])

print(batch_callback.all_logs)
# {'loss': [0.3391808867454529, 0.33950310945510864, 0.3395831882953644],
#  'accuracy': [1.0, 1.0, 1.0],
#  'batch': [0, 2, 4]}

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM