简体   繁体   English

在 tensorflow2.0 中,如果我使用 tf.keras.models.Model。 我可以通过模型训练批次的数量来评估和保存模型吗?

[英]In tensorflow2.0, if I use tf.keras.models.Model. Can I evaluate and save the model by the number of model training batches?

Now I am using the interface under tf.keras to encapsulate my CNN(Convolutional Neural Network).现在我使用tf.keras下的接口来封装我的CNN(Convolutional Neural Network)。

In the model.fit interface, I found that the parameter validation_freq only provides epoch options.model.fit界面,我发现参数validation_freq只提供了epoch选项。 This means that when training a model, you can only evaluate and save the model after training one or more epochs.这意味着在训练模型时,您只能在训练一个或多个 epoch 后评估和保存模型。 In many cases, it is more desirable to evaluate and save the model based on the number of model training batches.在很多情况下,更希望根据模型训练批次的数量来评估和保存模型。 Does model.fit or model.fit_generator provide this option now? model.fitmodel.fit_generator现在提供这个选项吗?

You could do by implementing custom callback like this:您可以通过实现这样的自定义回调来实现:

import tensorflow as tf
print(tf.__version__) # 2.1.0

class PerBatchLogsCallback(tf.keras.callbacks.Callback):
    def __init__(self, x_test, y_test, interval=2):
        self.x_test = x_test
        self.y_test = y_test
        self.n_batches = 0
        self.interval = interval
        self.all_logs = None

    def on_batch_end(self, batch, logs=None):
        if self.all_logs is None:
            self.all_logs = {name: [] for name in self.model.metrics_names}
            self.all_logs['batch'] = []

        if self.n_batches % self.interval == 0:
            evaluated = self.model.evaluate(self.x_test, self.y_test)
            for e, name in zip(evaluated, self.model.metrics_names):
                self.all_logs[name].append(e)
            self.all_logs['batch'].append(batch)
        self.n_batches += 1

Usage example that evaluates validation metrics every two batches:每两批评估验证指标的使用示例:

import tensorflow as tf 
import numpy as np

inputs = tf.keras.layers.Input((2,))
res = tf.keras.layers.Dense(2, activation=tf.nn.softmax)(inputs)
model = tf.keras.Model(inputs, res)
model.compile(optimizer=tf.keras.optimizers.SGD(0.001),
              metrics=['accuracy'],
              loss='sparse_categorical_crossentropy')

x_train = np.random.normal(0, 1, (10, 2))
y_train = np.random.randint(0, 2, 10)

x_test = np.random.normal(0, 1, (5, 2))
y_test = np.random.randint(0, 2, 5)

batch_callback = PerBatchLogsCallback(x_test, y_test)

model.fit(x_train,
          y_train,
          batch_size=2,
          epochs=1,
          callbacks=[batch_callback])

print(batch_callback.all_logs)
# {'loss': [0.3391808867454529, 0.33950310945510864, 0.3395831882953644],
#  'accuracy': [1.0, 1.0, 1.0],
#  'batch': [0, 2, 4]}

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 使用 Keras 在 Tensorflow 2.0 中无需训练即可保存模型 - Save model without training in Tensorflow 2.0 with Keras Tensorflow2.0培训:model.compile与GradientTape - Tensorflow2.0 training: model.compile vs GradientTape 您可以将训练参数与 tf.keras.Model() 构造函数一起使用吗? - Can you use the training parameter with tf.keras.Model() constructor? How can I use tf.keras.Model.summary to see the layers of a child model which in a father model? - How can I use tf.keras.Model.summary to see the layers of a child model which in a father model? 如何保存使用 tf.keras.estimator.model_to_estimator 创建的 tensorflow 估算器? - How do I save a tensorflow estimator created with tf.keras.estimator.model_to_estimator? 是否可以使用张量流回调将纪元结果记录在tf.keras模型中,以便在训练结束时保存? - Is it possible to log the epoch results in the tf.keras model using a tensorflow callback, in order to save at the end of training? 如何使用 tf.keras.Model - TensorFlow 2.0 - 子类化 API 保存和恢复模式的权重 - How to save and restore a mode's weights with tf.keras.Model - TensorFlow 2.0 - Subclassing API 使用 TF 2.0 为 Tensorflow/Keras model 提供嵌入层问题 - Issue with embedding layer when serving a Tensorflow/Keras model with TF 2.0 model = tf.keras.models.load_model() - model = tf.keras.models.load_model() 如何使用 tensorflow 和 keras 加快我的 model 训练过程 - How can i speed up my model training process using tensorflow and keras
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM