
Evaluate model on Testing Set after each epoch of training

I'm training a TensorFlow model on an image dataset for a classification task. Normally we pass the training set and the validation set to model.fit, and afterwards we can plot the model's convergence curves for training and validation. I want to do the same with the testing set: I want the model's accuracy and loss on the testing set after each epoch. This is in addition to the validation set, not instead of it; I can't replace the validation set with the testing set because I need graphs for both.

I managed to do this by saving a checkpoint of the model after each epoch with a callback, then loading each checkpoint back into the model and computing accuracy and loss. Is there an easier way, perhaps another callback or a workaround with model.fit?

You could write a custom Callback, pass it your test data, and compute whatever you like at the end of each epoch:

import tensorflow as tf
import pathlib
import numpy as np

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 5

train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  seed=123,
  image_size=(64, 64),
  batch_size=batch_size)

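# Demo only: these 30 batches come from the training data, so they are not a
# true held-out set. In practice, load a separate test directory (e.g. with shuffle=False).
test_ds = train_ds.take(30)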

model = tf.keras.Sequential([
  tf.keras.layers.Rescaling(1./255, input_shape=(64, 64, 3)),
  tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(5)
])

class TestCallback(tf.keras.callbacks.Callback):
    """Computes loss and accuracy on a separate test dataset at the end of every epoch."""

    def __init__(self, test_dataset):
        super().__init__()
        self.test_dataset = test_dataset
        self.test_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
        # A loss *metric* accumulates per-sample losses, so a smaller last batch
        # is weighted correctly (a plain mean of per-batch means would not be).
        self.test_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)

    def on_epoch_end(self, epoch, logs=None):
        for x_batch_test, y_batch_test in self.test_dataset:
            test_logits = self.model(x_batch_test, training=False)
            self.test_loss_metric.update_state(y_batch_test, test_logits)
            self.test_acc_metric.update_state(y_batch_test, test_logits)
        if logs is not None:
            # New keys are picked up by History, so they appear in history.history.
            logs['test_loss'] = float(self.test_loss_metric.result())
            logs['test_sparse_categorical_accuracy'] = float(self.test_acc_metric.result())
        # Clear the accumulated state before the next epoch.
        self.test_loss_metric.reset_state()
        self.test_acc_metric.reset_state()

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
epochs = 5
history = model.fit(train_ds, epochs=epochs, callbacks=[TestCallback(test_ds)])
Output:

Found 3670 files belonging to 5 classes.
Epoch 1/5
734/734 [==============================] - 14s 17ms/step - loss: 1.2709 - sparse_categorical_accuracy: 0.4591 - test_loss: 1.0020 - test_sparse_categorical_accuracy: 0.5533
Epoch 2/5
734/734 [==============================] - 13s 18ms/step - loss: 0.9574 - sparse_categorical_accuracy: 0.6275 - test_loss: 0.8348 - test_sparse_categorical_accuracy: 0.6467
Epoch 3/5
734/734 [==============================] - 9s 12ms/step - loss: 0.8136 - sparse_categorical_accuracy: 0.6733 - test_loss: 0.8379 - test_sparse_categorical_accuracy: 0.6467
Epoch 4/5
734/734 [==============================] - 8s 11ms/step - loss: 0.6970 - sparse_categorical_accuracy: 0.7357 - test_loss: 0.5713 - test_sparse_categorical_accuracy: 0.7533
Epoch 5/5
734/734 [==============================] - 8s 11ms/step - loss: 0.5793 - sparse_categorical_accuracy: 0.7834 - test_loss: 0.5656 - test_sparse_categorical_accuracy: 0.7733
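On how to reduce the test loss: averaging per-batch mean losses gives the true per-sample mean only when all batches have the same size. If the last batch is smaller, each of its samples is over-weighted. A minimal pure-Python illustration, using made-up loss values and hypothetical helper names:

```python
def mean_of_batch_means(batch_losses):
    """Unweighted average of per-batch mean losses (what tf.reduce_mean over batch means does)."""
    return sum(sum(b) / len(b) for b in batch_losses) / len(batch_losses)


def per_sample_mean(batch_losses):
    """Average over all samples, i.e. batch means weighted by batch size."""
    total = sum(sum(b) for b in batch_losses)
    count = sum(len(b) for b in batch_losses)
    return total / count


# Hypothetical per-sample losses: two full batches of 4 and a final batch of 1.
batches = [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [4.0]]

print(mean_of_batch_means(batches))  # 2.0: the single last sample counts as much as a full batch
print(per_sample_mean(batches))      # about 1.333: every sample counts once
```

In the run above the 3670 images split into 734 full batches of 5, so every test batch is full and both reductions agree there. With a ragged final batch, accumulating per-sample losses (for example with a Keras loss metric such as tf.keras.metrics.SparseCategoricalCrossentropy) gives the weighted result.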

You can also simply call model.evaluate on the test data inside the callback. See also this post.
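A minimal sketch of the model.evaluate variant, assuming TensorFlow 2.x (tf.keras) and using small random tensors as stand-in data so it runs on its own. The class name EvaluateCallback and the "test_" key prefix are hypothetical choices. evaluate(..., return_dict=True) reuses the loss and metrics passed to compile(), so the test numbers are computed the same way as the training ones, and writing them into logs lets History record them in history.history.

```python
import numpy as np
import tensorflow as tf


class EvaluateCallback(tf.keras.callbacks.Callback):
    """Hypothetical helper: runs model.evaluate on a held-out dataset after every epoch."""

    def __init__(self, test_dataset, prefix="test_"):
        super().__init__()
        self.test_dataset = test_dataset
        self.prefix = prefix

    def on_epoch_end(self, epoch, logs=None):
        # evaluate() reuses the compiled loss and metrics; verbose=0 keeps the log quiet.
        results = self.model.evaluate(self.test_dataset, verbose=0, return_dict=True)
        if logs is not None:
            # Keys added here end up in history.history (e.g. "test_loss").
            for name, value in results.items():
                logs[self.prefix + name] = value


# Stand-in data: random 8-feature inputs with 3 classes (not the flowers dataset).
rng = np.random.default_rng(0)
x_train = rng.normal(size=(64, 8)).astype("float32")
y_train = rng.integers(0, 3, size=64)
x_test = rng.normal(size=(32, 8)).astype("float32")
y_test = rng.integers(0, 3, size=32)

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(16)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(16)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(3),  # logits
])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

history = model.fit(train_ds, epochs=2, verbose=0, callbacks=[EvaluateCallback(test_ds)])
print(sorted(history.history))
```

Compared with the manual loop, this keeps a single definition of the loss and metrics (the one in compile()), at the cost of an extra evaluate pass whose overhead is small for a test set of this size.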
