
A callback to check the saturation of val_acc

Usually, we can define a callback for a model that stops training once the accuracy reaches a certain level.

I am working on tuning the parameters. The val_acc is highly unstable, as shown in the attached accuracy graphs.

def LSTM_model(X_train, y_train, X_test, y_test, num_classes, batch_size=68, units=128, learning_rate=0.005, epochs=20,
               dropout=0.2, recurrent_dropout=0.2):
    class myCallback(tf.keras.callbacks.Callback):

        def on_epoch_end(self, epoch, logs=None):
            # stop training once training accuracy exceeds 90%
            if logs.get('acc', 0) > 0.90:
                print("\nReached 90% accuracy so cancelling training!")
                self.model.stop_training = True

    callbacks = myCallback()

As the graphs show, the val_acc (orange) is fluctuating within a range and not really going up anymore.

Is there a way to automatically stop the training once the general trend of the val_acc stops increasing?

You can achieve this with a callback like this:

class terminate_on_plateau(keras.callbacks.Callback):

    def __init__(self):
        self.patience = 10
        self.val_loss = deque([], self.patience)
        self.std_threshold = 1e-2

    def on_epoch_end(self, epoch, logs=None):
        # evaluate on the validation set and record the loss
        # (x_val and y_val are defined in the script below)
        val_loss, val_mae = self.model.evaluate(x_val, y_val, verbose=0)
        self.val_loss.append(val_loss)
        if len(self.val_loss) >= self.patience:
            std = np.std(self.val_loss)
            if std < self.std_threshold:
                print('\n\n EarlyStopping on std invoked! \n\n')
                # clear the deque
                self.val_loss = deque([], self.patience)
                self.model.stop_training = True

As you can see, in terminate_on_plateau the val_loss of each epoch is stored in a deque of maximum length self.patience. Once the deque is full, the standard deviation of the stored val_loss values is calculated after every new epoch, and the training process is terminated (and the deque of val_loss cleared) if the calculated std is smaller than a threshold.
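If you want to see the plateau test in isolation, here is a minimal sketch with made-up loss values (patience, std_threshold and the losses list are illustrative only):

from collections import deque
import numpy as np

patience = 5
std_threshold = 1e-2
window = deque([], patience)  # keeps only the last `patience` losses

# a loss curve that decreases and then plateaus (made-up numbers)
losses = [0.9, 0.5, 0.3, 0.2, 0.15, 0.14, 0.141, 0.139, 0.140, 0.1405]

for epoch, loss in enumerate(losses):
    window.append(loss)
    if len(window) >= patience:
        std = np.std(window)
        print('epoch', epoch, 'std of last', patience, 'losses =', round(std, 4))
        if std < std_threshold:
            print('plateau detected, training would stop here')
            break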

Below is a simple script that shows you how to use this:

from collections import deque
import numpy as np

import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense

# toy regression data: y = sin(x) + x
x = np.linspace(0, 10, 1000).reshape(-1, 1)
np.random.shuffle(x)
y = np.sin(x) + x

x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.3)

input_x = Input(shape=(1,))
h = Dense(10, activation='relu')(input_x)
h = Dense(10, activation='relu')(h)
out = Dense(1, activation='relu')(h)
model = Model(inputs=input_x, outputs=out)

adamopt = tf.keras.optimizers.Adam(learning_rate=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-8)

class terminate_on_plateau(keras.callbacks.Callback):

    def __init__(self):
        self.patience = 10
        self.val_loss = deque([], self.patience)
        self.std_threshold = 1e-2

    def on_epoch_end(self, epoch, logs=None):
        # evaluate on the validation set and record the loss
        val_loss, val_mae = self.model.evaluate(x_val, y_val, verbose=0)
        self.val_loss.append(val_loss)
        if len(self.val_loss) >= self.patience:
            std = np.std(self.val_loss)
            if std < self.std_threshold:
                print('\n\n EarlyStopping on std invoked! \n\n')
                # clear the deque
                self.val_loss = deque([], self.patience)
                self.model.stop_training = True
    
model.compile(loss='mse', optimizer=adamopt, metrics=['mae'])
history = model.fit(x_train, y_train,
                    batch_size=8,
                    epochs=100,
                    validation_data=(x_val, y_val),
                    verbose=1,
                    callbacks=[terminate_on_plateau()])

The code below is for a custom callback that stops training when the monitored quantity fails to improve for a patience number of epochs. Set the parameter acc_or_loss to 'loss' to monitor validation loss, or to 'acc' to monitor validation accuracy. I recommend NOT monitoring validation accuracy, as it can swing wildly, particularly in the early epochs. I put in print statements so you can see what is going on during training; you can of course remove them later. If you are monitoring validation loss, the callback halts training when the validation loss has exceeded the lowest loss found so far for a patience number of consecutive epochs. If you are monitoring validation accuracy, it halts training when the validation accuracy has stayed below the highest accuracy recorded so far for a patience number of consecutive epochs.

class halt(keras.callbacks.Callback):
    def __init__(self, patience, acc_or_loss):
        self.acc_or_loss = acc_or_loss
        super(halt, self).__init__()
        self.patience = patience  # how many epochs without improvement before training is halted
        self.lowest_loss = np.inf
        self.highest_acc = 0
        self.count = 0
        print('initializing values ', 'count= ', self.count, '  lowest_loss= ', self.lowest_loss, 'highest acc= ', self.highest_acc)
    def on_epoch_end(self, epoch, logs=None):
        v_loss = logs.get('val_loss')  # get the validation loss for this epoch
        v_acc = logs.get('val_accuracy')  # get the validation accuracy for this epoch
        if self.acc_or_loss == 'loss':
            print(' for epoch ', epoch + 1, '  v_loss= ', v_loss, ' lowest_loss= ', self.lowest_loss, 'count= ', self.count)
            if v_loss < self.lowest_loss:
                self.lowest_loss = v_loss
                self.count = 0
            else:
                self.count = self.count + 1
                if self.count >= self.patience:
                    print('There have been ', self.patience, ' epochs with no reduction of validation loss below the lowest loss')
                    print('Terminating training')
                    self.model.stop_training = True
        else:
            print(' for epoch ', epoch + 1, '  v_acc= ', v_acc, ' highest accuracy= ', self.highest_acc, 'count= ', self.count)
            if v_acc > self.highest_acc:
                self.count = 0
                self.highest_acc = v_acc
            else:
                self.count = self.count + 1
                if self.count >= self.patience:
                    print('There have been ', self.patience, ' epochs with no increase in validation accuracy')
                    print('Terminating training')
                    self.model.stop_training = True

patience = 2  # specify the patience value
acc_or_loss = 'loss'  # specify whether to monitor validation loss or validation accuracy
callbacks = [halt(patience=patience, acc_or_loss=acc_or_loss)]
# in model.fit include callbacks=callbacks, as sketched below
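For completeness, a minimal usage sketch; model, X_train, y_train, X_val and y_val are placeholders for your own compiled model and data:

history = model.fit(X_train, y_train,
                    validation_data=(X_val, y_val),  # required so val_loss / val_accuracy appear in logs
                    epochs=100,
                    batch_size=32,
                    callbacks=callbacks)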

Or you can just use the Keras API in TensorFlow: tf.keras.callbacks.EarlyStopping

Given your initial question, I'm not sure why you would need custom callbacks.

Here is an example of its application:

history = model.fit([trainX, trainX, trainX],
                    np.array(trainLabels),
                    validation_data=([testX, testX, testX], np.array(testLabels)),
                    epochs=EPOCH,
                    batch_size=BATCH_SIZE,
                    steps_per_epoch=None,
                    callbacks=[tf.keras.callbacks.EarlyStopping(
                        monitor="val_acc",
                        patience=5,
                        mode="max",  # accuracy is monitored, so larger is better
                        restore_best_weights=True)])
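Note that the metric name depends on how the model was compiled: with metrics=['accuracy'], recent Keras versions log val_accuracy rather than val_acc, so match the monitor argument to whatever appears in your training logs. Also, since the question is about the general trend rather than any single epoch, the min_delta argument of EarlyStopping can be used to ignore improvements that are too small to matter; a sketch with illustrative values:

import tensorflow as tf

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',   # or 'val_acc', depending on your logs
    mode='max',
    min_delta=1e-3,           # changes smaller than this do not count as improvement
    patience=10,              # tolerate up to 10 epochs of fluctuation
    restore_best_weights=True)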

Some of the above answers are a little complex; you can use the code below.

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

opt = tf.optimizers.Adadelta(learning_rate=0.01)
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=["accuracy"])
es = EarlyStopping(monitor='val_accuracy', mode='max', patience=20)
# stops training if validation accuracy does not improve for 20 epochs; you can give any number as patience
ms = ModelCheckpoint('save_model.h5', monitor='val_accuracy', mode='max', save_best_only=True)

training_history = model.fit(x=X_train, y=y_train, validation_split=0.1, batch_size=5, epochs=1000, verbose=1,
                             callbacks=[es, ms])

I just copied this code from my project, which is not for an LSTM; you can adjust it to your problem/task.
