I am using the tf.keras API to build my CNN (convolutional neural network).
In model.fit, I found that the validation_freq parameter is expressed only in epochs (an integer, or a list of epoch indices). This means that during training you can only evaluate and save the model after one or more complete epochs. In many cases it is more useful to evaluate and save the model every N training batches. Do model.fit or model.fit_generator currently provide such an option?
You can do this by implementing a custom callback, for example:
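import tensorflow as tf
print(tf.__version__)  # 2.1.0


class PerBatchLogsCallback(tf.keras.callbacks.Callback):
    """Evaluates the model on (x_test, y_test) every `interval` training batches."""

    def __init__(self, x_test, y_test, interval=2):
        super().__init__()
        self.x_test = x_test
        self.y_test = y_test
        self.n_batches = 0        # global batch counter, keeps running across epochs
        self.interval = interval  # evaluate once every `interval` batches
        self.all_logs = None      # created lazily on the first batch, once metrics_names is populated

    def on_batch_end(self, batch, logs=None):
        if self.all_logs is None:
            self.all_logs = {name: [] for name in self.model.metrics_names}
            self.all_logs['batch'] = []
        if self.n_batches % self.interval == 0:
            # verbose=0 keeps evaluate() from printing its own progress bar mid-training
            evaluated = self.model.evaluate(self.x_test, self.y_test, verbose=0)
            for e, name in zip(evaluated, self.model.metrics_names):
                self.all_logs[name].append(e)
            self.all_logs['batch'].append(batch)  # batch index within the current epoch
        self.n_batches += 1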
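To see which batches trigger an evaluation, the counting logic of the callback can be replayed in plain Python, without TensorFlow. This is a minimal sketch: the helper evaluated_batches is hypothetical and only mirrors the n_batches % interval check in on_batch_end. It also makes visible that n_batches keeps counting across epochs while the batch argument restarts at 0 every epoch, so with several epochs the recorded 'batch' values are per-epoch indices.

```python
def evaluated_batches(steps_per_epoch, epochs, interval):
    """Return the (epoch, batch) pairs at which PerBatchLogsCallback would evaluate.

    Hypothetical helper: mirrors the callback, where a global counter
    (n_batches) runs across epochs while `batch` restarts at 0 each epoch.
    """
    hits = []
    n_batches = 0
    for epoch in range(epochs):
        for batch in range(steps_per_epoch):
            if n_batches % interval == 0:
                hits.append((epoch, batch))
            n_batches += 1
    return hits


# 10 samples with batch_size=2 -> 5 batches per epoch, evaluated every 2 batches
print(evaluated_batches(steps_per_epoch=5, epochs=1, interval=2))
# [(0, 0), (0, 2), (0, 4)]
print(evaluated_batches(steps_per_epoch=5, epochs=2, interval=2))
# [(0, 0), (0, 2), (0, 4), (1, 1), (1, 3)]
```

In the second run the evaluations in epoch 1 land on batches 1 and 3, because the interval is counted globally rather than restarting each epoch.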
Usage example that evaluates validation metrics every two batches:
import tensorflow as tf
import numpy as np

inputs = tf.keras.layers.Input((2,))
res = tf.keras.layers.Dense(2, activation=tf.nn.softmax)(inputs)
model = tf.keras.Model(inputs, res)
model.compile(optimizer=tf.keras.optimizers.SGD(0.001),
              metrics=['accuracy'],
              loss='sparse_categorical_crossentropy')

# 10 training samples with batch_size=2 -> 5 batches per epoch
x_train = np.random.normal(0, 1, (10, 2))
y_train = np.random.randint(0, 2, 10)
x_test = np.random.normal(0, 1, (5, 2))
y_test = np.random.randint(0, 2, 5)

batch_callback = PerBatchLogsCallback(x_test, y_test)
model.fit(x_train,
          y_train,
          batch_size=2,
          epochs=1,
          callbacks=[batch_callback])
print(batch_callback.all_logs)
# Example output (the numbers vary between runs because the data is random):
# {'loss': [0.3391808867454529, 0.33950310945510864, 0.3395831882953644],
#  'accuracy': [1.0, 1.0, 1.0],
#  'batch': [0, 2, 4]}
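For the "save every N batches" part of the question, the built-in tf.keras.callbacks.ModelCheckpoint accepts an integer save_freq. As far as I know, in TF 2.2 and later (and in Keras 3) an integer save_freq means "save after this many batches", whereas in TF 2.1 it counted samples, so check the documentation for your version. The sketch below assumes TF 2.2+ semantics; the file name model.weights.h5 and the temporary directory are arbitrary choices, and save_weights_only=True is used so that the same path works with both Keras 2 and Keras 3.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Same toy model and data shapes as in the usage example above
inputs = tf.keras.layers.Input((2,))
outputs = tf.keras.layers.Dense(2, activation='softmax')(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer=tf.keras.optimizers.SGD(0.001),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

x_train = np.random.normal(0, 1, (10, 2))
y_train = np.random.randint(0, 2, 10)

# Arbitrary checkpoint location; the file is overwritten on every save
ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, 'model.weights.h5')

# save_freq=2: write the weights every 2 training batches (TF >= 2.2 semantics)
checkpoint = tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                                                save_weights_only=True,
                                                save_freq=2)
model.fit(x_train, y_train, batch_size=2, epochs=1,
          callbacks=[checkpoint], verbose=0)
print(os.path.exists(ckpt_path))
```

Combined with a per-batch evaluation callback like PerBatchLogsCallback, this covers both evaluating and saving on a batch schedule.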