In TensorFlow 2.0, when using tf.keras.models.Model, can I evaluate and save the model every N training batches?

I am currently using the tf.keras API to build my CNN (convolutional neural network).

In model.fit, the validation_freq parameter only accepts epoch-based values. This means the model can only be evaluated and saved after one or more complete epochs. In many cases it is preferable to evaluate and save the model every N training batches instead. Do model.fit or model.fit_generator currently provide such an option?
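To make the limitation concrete, here is a minimal sketch on random toy data (all variable names are illustrative). It shows that validation_freq counts epochs: with epochs=4 and validation_freq=2, validation runs only at the end of epochs 2 and 4, so val_loss is recorded twice while the training loss is recorded four times.

```python
import numpy as np
import tensorflow as tf

# Random toy data: 20 samples, 2 features, binary labels
x = np.random.normal(0, 1, (20, 2)).astype("float32")
y = np.random.randint(0, 2, 20)

inputs = tf.keras.layers.Input((2,))
outputs = tf.keras.layers.Dense(2, activation="softmax")(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="sgd", loss="sparse_categorical_crossentropy")

# validation_freq=2: run validation at the end of every 2nd epoch (epochs 2 and 4)
history = model.fit(x, y, batch_size=5, epochs=4, verbose=0,
                    validation_data=(x, y), validation_freq=2)

print(len(history.history["loss"]), len(history.history["val_loss"]))  # 4 2
```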

You can do this by implementing a custom callback, for example:

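```python
import tensorflow as tf
print(tf.__version__)  # 2.1.0


class PerBatchLogsCallback(tf.keras.callbacks.Callback):
    """Evaluates the model on (x_test, y_test) every `interval` training batches."""

    def __init__(self, x_test, y_test, interval=2):
        super().__init__()
        self.x_test = x_test
        self.y_test = y_test
        self.n_batches = 0       # global batch counter (does not reset between epochs)
        self.interval = interval
        self.all_logs = None     # metric name -> list of values, created lazily

    def on_train_batch_end(self, batch, logs=None):
        # Build the log dict on the first batch, once metrics_names is populated
        if self.all_logs is None:
            self.all_logs = {name: [] for name in self.model.metrics_names}
            self.all_logs['batch'] = []

        if self.n_batches % self.interval == 0:
            evaluated = self.model.evaluate(self.x_test, self.y_test, verbose=0)
            # evaluate() returns a plain scalar when only the loss is compiled
            if not isinstance(evaluated, list):
                evaluated = [evaluated]
            for value, name in zip(evaluated, self.model.metrics_names):
                self.all_logs[name].append(value)
            self.all_logs['batch'].append(self.n_batches)  # global batch index
        self.n_batches += 1
```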

Usage example that evaluates validation metrics every two batches:

```python
import tensorflow as tf
import numpy as np

inputs = tf.keras.layers.Input((2,))
res = tf.keras.layers.Dense(2, activation=tf.nn.softmax)(inputs)
model = tf.keras.Model(inputs, res)
model.compile(optimizer=tf.keras.optimizers.SGD(0.001),
              metrics=['accuracy'],
              loss='sparse_categorical_crossentropy')

# Toy data: 10 training samples -> 5 batches of size 2
x_train = np.random.normal(0, 1, (10, 2))
y_train = np.random.randint(0, 2, 10)

x_test = np.random.normal(0, 1, (5, 2))
y_test = np.random.randint(0, 2, 5)

batch_callback = PerBatchLogsCallback(x_test, y_test)  # interval=2 by default

model.fit(x_train,
          y_train,
          batch_size=2,
          epochs=1,
          callbacks=[batch_callback])

print(batch_callback.all_logs)
# Example output (exact values vary because the data is random):
# {'loss': [0.3391808867454529, 0.33950310945510864, 0.3395831882953644],
#  'accuracy': [1.0, 1.0, 1.0],
#  'batch': [0, 2, 4]}
```
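The callback above only records metrics, while the question also asks about saving. Below is a minimal sketch of one way to combine both, assuming a sparse-label classifier like the one in the usage example. Every `interval` batches it computes the validation loss with a direct forward pass (instead of a nested `evaluate()` call) and saves the weights with `save_weights` whenever that loss improves. The class name `EvalAndSaveEveryNBatches`, the `filepath` template and the save-on-improvement rule are illustrative choices, not part of the Keras API. The `.weights.h5` suffix is used because both older `tf.keras` (as HDF5) and Keras 3 accept it.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class EvalAndSaveEveryNBatches(tf.keras.callbacks.Callback):
    """Hypothetical callback: every `interval` training batches, compute the
    validation loss and save the weights whenever that loss improves."""

    def __init__(self, x_val, y_val, filepath, interval=2):
        super().__init__()
        self.x_val = np.asarray(x_val, dtype="float32")
        self.y_val = np.asarray(y_val)
        self.filepath = filepath   # template with a {step} field, ends in .weights.h5
        self.interval = interval
        self.step = 0              # global batch counter across epochs
        self.best = np.inf
        self.val_losses = []       # (step, loss) pairs
        self.saved_paths = []

    def on_train_batch_end(self, batch, logs=None):
        if self.step % self.interval == 0:
            # Inference-mode forward pass; assumes sparse integer labels
            preds = self.model(self.x_val, training=False)
            loss = float(np.mean(
                tf.keras.losses.sparse_categorical_crossentropy(self.y_val, preds)))
            self.val_losses.append((self.step, loss))
            if loss < self.best:  # save only when the validation loss improves
                self.best = loss
                path = self.filepath.format(step=self.step)
                self.model.save_weights(path)
                self.saved_paths.append(path)
        self.step += 1


# Usage on random toy data: 10 samples, batch_size=2 -> 5 batches per epoch
x_train = np.random.normal(0, 1, (10, 2)).astype("float32")
y_train = np.random.randint(0, 2, 10)
x_val = np.random.normal(0, 1, (5, 2)).astype("float32")
y_val = np.random.randint(0, 2, 5)

inputs = tf.keras.layers.Input((2,))
outputs = tf.keras.layers.Dense(2, activation="softmax")(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer=tf.keras.optimizers.SGD(0.001),
              loss="sparse_categorical_crossentropy")

ckpt_dir = tempfile.mkdtemp()
callback = EvalAndSaveEveryNBatches(
    x_val, y_val, os.path.join(ckpt_dir, "ckpt_{step:05d}.weights.h5"), interval=2)
model.fit(x_train, y_train, batch_size=2, epochs=1, verbose=0, callbacks=[callback])

print(callback.val_losses)   # checks at global steps 0, 2 and 4
print(callback.saved_paths)
```

Computing the loss directly keeps the check cheap and avoids running a full `evaluate()` inside `fit()`. If you only need periodic saving, `tf.keras.callbacks.ModelCheckpoint` also accepts an integer `save_freq`; check its documentation for your TF version, because the meaning of that integer has changed across 2.x releases.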
