I am doing multi-class classification for a recommender system (item recommendations), and I'm currently training my network using sparse_categorical_crossentropy
loss. Therefore, it is reasonable to perform EarlyStopping
by monitoring my validation loss, val_loss
as such:
tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
which works as expected. However, the performance of the network (recommender system) is measured by Average-Precision-at-10, and is tracked as a metric during training, as average_precision_at_k10
. Because of this, I could also perform early stopping with this metric as such:
tf.keras.callbacks.EarlyStopping(monitor='average_precision_at_k10', patience=10)
which also works as expected.
My problem: Sometimes the validation loss increases whilst the Average-Precision-at-10 is improving, and vice versa. Because of this, I would need to monitor both, and perform early stopping if and only if both are deteriorating. What I would like to do:
tf.keras.callbacks.EarlyStopping(monitor=['val_loss', 'average_precision_at_k10'], patience=10)
which obviously does not work. Any ideas how this could be done?
With guidance from Gerry P above, I managed to create my own custom EarlyStopping callback, and thought I would post it here in case anyone else is looking to implement something similar.
If both the validation loss and the mean average precision at 10 fail to improve for patience
epochs in a row, early stopping is performed.
class CustomEarlyStopping(keras.callbacks.Callback):
    """Stop training once validation loss AND validation MAP@10 both stall.

    The two monitored quantities are read from the epoch logs under the keys
    ``'val_loss'`` (lower is better) and ``'val_average_precision_at_k10'``
    (higher is better). Each one has its own running best. An epoch counts as
    an improvement when *either* quantity beats its best; training is stopped
    only after ``patience`` consecutive epochs in which *neither* did.

    Args:
        patience: Number of consecutive non-improving epochs to tolerate
            before stopping.

    Attributes:
        best_weights: Model weights captured at the most recent improving
            epoch, or ``None`` if no epoch has improved yet.
        stopped_epoch: Zero-based index of the epoch at which training was
            stopped (0 if it was not stopped early).
    """

    def __init__(self, patience=0):
        super().__init__()
        self.patience = patience
        self.best_weights = None

    def on_train_begin(self, logs=None):
        """Reset all per-run state so the callback can be reused across fits."""
        # Number of consecutive epochs without any improvement.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Start from the worst possible values so the first epoch always improves.
        # (np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.)
        self.best_v_loss = np.inf
        self.best_map10 = 0
        self.best_weights = None

    def on_epoch_end(self, epoch, logs=None):
        """Update the running bests and stop when both metrics have stalled.

        Raises:
            ValueError: If either monitored key is missing from ``logs``.
        """
        logs = logs or {}
        v_loss = logs.get('val_loss')
        map10 = logs.get('val_average_precision_at_k10')
        if v_loss is None or map10 is None:
            # Fail with an explicit message instead of an opaque TypeError
            # from comparing None against a number.
            raise ValueError(
                "CustomEarlyStopping needs 'val_loss' and "
                "'val_average_precision_at_k10' in the epoch logs; "
                f"available keys: {sorted(logs)}"
            )

        loss_improved = v_loss < self.best_v_loss
        map10_improved = map10 > self.best_map10
        # Track each best independently so a gain in one metric is never
        # discarded just because the other did not move in the same epoch.
        if loss_improved:
            self.best_v_loss = v_loss
        if map10_improved:
            self.best_map10 = map10

        if loss_improved or map10_improved:
            # At least one metric is still getting better: keep training.
            self.wait = 0
            self.best_weights = self.model.get_weights()
            return

        # Neither metric improved this epoch.
        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            if self.best_weights is not None:
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        """Report the (one-based) epoch at which early stopping happened."""
        if self.stopped_epoch > 0:
            print(f"Epoch {self.stopped_epoch + 1:05d}: early stopping")
It is then used as:
# Build the callback list up front so the fit call itself stays short.
# NOTE(review): CustomEarlyStopping reads 'val_loss' and
# 'val_average_precision_at_k10' from the epoch logs, but this call passes no
# validation data - confirm that validation metrics are actually produced.
early_stopping_callbacks = [CustomEarlyStopping(patience=10)]
model.fit(
    x_train,
    y_train,
    epochs=30,
    steps_per_epoch=5,
    batch_size=64,
    callbacks=early_stopping_callbacks,
    verbose=0,
)
You can achieve this by creating a custom callback. Information on how to do that is located here. Below is some code that illustrates what you can do in a custom callback. The documentation I referenced shows many other options.
class LRA(keras.callbacks.Callback):  # subclass the callback class
    """Template callback showing what can be done at the end of each epoch.

    It demonstrates reading values from the epoch logs, storing state on the
    class, and scaling the optimizer's learning rate by ``param2`` whenever
    the validation loss exceeds ``param1``.

    Args:
        param1: Validation-loss threshold above which the learning rate is
            adjusted.
        param2: Factor the current learning rate is multiplied by.
        etc: Stand-in for any further parameters you want to feed to the
            callback.
    """

    # Class variables: these can be accessed outside the class definition as
    # LRA.my_class_variable and LRA.best_weights.
    my_class_variable = something  # placeholder - substitute a real initial value
    # Filled from the attached model when training starts, rather than from a
    # global `model` at class-definition time.
    best_weights = None

    # define an initialization function with parameters you want to feed to the callback
    def __init__(self, param1, param2, etc):
        super().__init__()
        self.param1 = param1
        self.param2 = param2
        # ...and so on for all parameters
        # write any initialization code you need here

    def on_train_begin(self, logs=None):
        """Seed the class-level weight snapshot from the model being trained."""
        LRA.best_weights = self.model.get_weights()

    def on_epoch_end(self, epoch, logs=None):  # method runs on the end of each epoch
        """Record the current weights and adjust the learning rate if needed."""
        logs = logs or {}
        v_loss = logs.get('val_loss')  # example of getting log data: the validation loss for this epoch
        acc = logs.get('accuracy')  # another example of getting log data (unused here)
        # Use the model Keras attached to this callback, not a global `model`.
        LRA.best_weights = self.model.get_weights()  # example of setting class variable value
        print(f'Hello epoch {epoch} has just ended')  # print a message at the end of every epoch
        optimizer = self.model.optimizer
        lr = float(tf.keras.backend.get_value(optimizer.learning_rate))  # get the current learning rate
        # 'val_loss' is absent when no validation data is used; skip the
        # adjustment instead of failing on a None comparison.
        if v_loss is not None and v_loss > self.param1:
            new_lr = lr * self.param2
            tf.keras.backend.set_value(optimizer.learning_rate, new_lr)  # set the learning rate in the optimizer
        # write whatever code you need
I recommend you to create your own callback. In the following I added a solution that monitors both the accuracy and the loss. You can replace the acc with your own metric:
class CustomCallback(keras.callbacks.Callback):
    """Early stopping that compares each epoch with the one ``patience`` epochs back.

    After every epoch the training ``'loss'`` and ``'accuracy'`` are stored
    under the one-based epoch number. Once more than ``patience`` epochs have
    run, training continues only while the current loss is lower AND the
    current accuracy is higher than they were ``patience`` epochs earlier;
    otherwise training is stopped and the last saved weights are restored.

    Args:
        patience: Look-back distance in epochs. ``None`` (or 0) disables
            stopping; the callback then only keeps the latest weights.

    Attributes:
        loss: Mapping of one-based epoch number to training loss.
        acc: Mapping of one-based epoch number to training accuracy.
        best_weights: Weights saved at the last epoch that passed the check.
    """

    def __init__(self, patience=None):
        super().__init__()
        self.patience = patience
        # Per-instance state: class-level dicts would be shared by every
        # instance and would carry history over from one fit to the next.
        self.acc = {}
        self.loss = {}
        self.best_weights = None

    def on_train_begin(self, logs=None):
        """Clear the history so a reused callback starts from scratch."""
        self.acc = {}
        self.loss = {}
        self.best_weights = None

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics and decide whether to keep training."""
        logs = logs or {}
        epoch_number = epoch + 1  # Keras passes a zero-based epoch index
        self.loss[epoch_number] = logs['loss']
        self.acc[epoch_number] = logs['accuracy']

        if not self.patience or epoch_number <= self.patience:
            # Not enough history yet: the current weights are the best so far.
            self.best_weights = self.model.get_weights()
            return

        reference = epoch_number - self.patience
        # Improvement means a lower loss and, similarly, a higher accuracy
        # than at the reference epoch.
        improved = (
            self.loss[epoch_number] < self.loss[reference]
            and self.acc[epoch_number] > self.acc[reference]
        )
        if improved:
            self.best_weights = self.model.get_weights()
        else:
            # Stop training and load the best weights.
            self.model.stop_training = True
            self.model.set_weights(self.best_weights)
Please bear in mind that if you want to control the minimum change in the monitored quantity (a.k.a. min_delta), you have to integrate it into the code.
Here is the documentation on how to build your custom callback: custom_callback
At this point it would be simpler to write a custom training loop and just use if-statements. E.g.:
def main(epochs=50, acc_threshold=.8, topk_threshold=.9):
    """Train for up to `epochs` epochs, stopping early once both targets are met.

    Expects `fit`, `test_acc` and `topk_acc` to be available in the enclosing
    scope. After each call to `fit(epoch)` the two metric results are
    compared with their thresholds; the loop ends as soon as both are
    exceeded.

    Args:
        epochs: Maximum number of epochs to run.
        acc_threshold: Test accuracy that must be exceeded (default 0.8).
        topk_threshold: Top-K accuracy that must be exceeded (default 0.9).
    """
    for epoch in range(epochs):
        fit(epoch)
        if test_acc.result() > acc_threshold and topk_acc.result() > topk_threshold:
            # With the default thresholds this prints the same 80% / 90% message as before.
            print(f'\nEarly stopping. Test acc is above {acc_threshold:.0%} '
                  f'and TopK acc is above {topk_threshold:.0%}.')
            break


if __name__ == '__main__':
    main(epochs=100)
Here's a simple custom training loop using this method:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow_datasets as tfds
import tensorflow as tf
# Load the 'train' split of iris together with its metadata object; with
# as_supervised=True each element unpacks into an (inputs, target) pair.
data, info = tfds.load(
    name='iris',
    split='train',
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
def preprocessing(inputs, targets):
    """Scale `inputs` by its maximum along axis 0; return `targets` untouched."""
    peak = tf.reduce_max(inputs, axis=0)
    return tf.math.divide(inputs, peak), targets
# Keep labels <= 2, scale the features, then shuffle ONCE.
# reshuffle_each_iteration=False pins the order for every pass over the data;
# with the default (True) the take/skip split below would change on every
# epoch and test examples would leak into the training set.
dataset = (
    data.filter(lambda x, y: tf.less_equal(y, 2))
    .map(preprocessing)
    .shuffle(info.splits['train'].num_examples, reshuffle_each_iteration=False)
)

# First 120 examples train, next 30 test. The training portion is reshuffled
# each epoch *after* the split, so batch order still varies between epochs.
train_dataset = dataset.take(120).shuffle(120).batch(4)
test_dataset = dataset.skip(120).take(30).batch(4)
# Small fully connected classifier. The last layer has no activation, so it
# emits raw logits: `loss_object` is created with from_logits=True and applies
# the softmax itself. A softmax here as well would normalise the scores twice
# and flatten the gradients.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(info.features['label'].num_classes),
])
# Loss: with from_logits=True the loss treats predictions as unnormalised scores.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Running means of the per-batch loss, one per data split.
train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

# Accuracy per split, plus top-2 accuracy (updated in test_step only).
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()
topk_acc = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=2)

# Optimizer used by train_step.
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
@tf.function
def train_step(inputs, labels):
    """Run one optimisation step on a batch and update the training metrics."""
    with tf.GradientTape() as tape:
        outputs = model(inputs)
        batch_loss = loss_object(labels, outputs)

    # Differentiate the batch loss w.r.t. every trainable variable and apply.
    variables = model.trainable_variables
    grads = tape.gradient(batch_loss, variables)
    opt.apply_gradients(zip(grads, variables))

    # Fold this batch into the running training metrics.
    train_loss.update_state(batch_loss)
    train_acc.update_state(labels, outputs)
@tf.function
def test_step(inputs, labels):
    """Evaluate one batch (no weight update) and update the test metrics."""
    outputs = model(inputs)
    test_loss.update_state(loss_object(labels, outputs))
    # Both accuracy metrics consume the same (labels, outputs) pair.
    for metric in (test_acc, topk_acc):
        metric.update_state(labels, outputs)
def fit(epoch):
    """Run one epoch of training and evaluation, then print a summary line.

    Every metric is reset first, so the printed values (and whatever the
    metrics hold afterwards) describe this epoch only. `epoch` is zero-based
    and is printed as `epoch + 1`.
    """
    # Order matters: it matches the placeholders in the report template.
    metrics = (train_loss, test_loss, train_acc, test_acc, topk_acc)
    for metric in metrics:
        metric.reset_states()

    for batch_inputs, batch_labels in train_dataset:
        train_step(batch_inputs, batch_labels)

    for batch_inputs, batch_labels in test_dataset:
        test_step(batch_inputs, batch_labels)

    template = ('Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} '
                'Train Acc {:.2f} Test Acc {:.2f} Test TopK Acc {:.2f} ')
    print(template.format(epoch + 1, *(metric.result() for metric in metrics)))
def main(epochs=50, acc_threshold=.8, topk_threshold=.9):
    """Train for up to `epochs` epochs, stopping early once both targets are met.

    After each call to `fit(epoch)` the results of the module-level
    `test_acc` and `topk_acc` metrics are compared with their thresholds;
    the loop ends as soon as both are exceeded.

    Args:
        epochs: Maximum number of epochs to run.
        acc_threshold: Test accuracy that must be exceeded (default 0.8).
        topk_threshold: Top-K accuracy that must be exceeded (default 0.9).
    """
    for epoch in range(epochs):
        fit(epoch)
        if test_acc.result() > acc_threshold and topk_acc.result() > topk_threshold:
            # With the default thresholds this prints the same 80% / 90% message as before.
            print(f'\nEarly stopping. Test acc is above {acc_threshold:.0%} '
                  f'and TopK acc is above {topk_threshold:.0%}.')
            break


if __name__ == '__main__':
    main(epochs=100)
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.