
How to call callback after n epochs but always in the last epoch of training?

I want to call a callback every n epochs, but also always in the last epoch of training. This post explains how to call the callback every n epochs.

At the moment I am using the following approach:

from tensorflow import keras


class MyCallBack(keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
        if epoch % 10 == 0:  # <- add additional condition here
            self._do_the_stuff()

    def _do_the_stuff(self):
        print('Do the stuff')

    def on_train_end(self, logs=None):
        self._do_the_stuff()

Is there a simpler way where I add an additional condition to the if statement inside on_epoch_end and don't need on_train_end?

As suggested by @Ewran in the comments above, the total number of epochs can be accessed via `self.params['epochs']`.

from tensorflow import keras


class MyCallBack(keras.callbacks.Callback):

    def __init__(self, epoch_freq=10):
        super().__init__()
        self.epoch_freq = epoch_freq

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is zero-based, so the last epoch is `epochs - 1`
        last_epoch = self.params.get('epochs', 0) - 1
        if epoch % self.epoch_freq == 0 or epoch == last_epoch:
            self._do_the_stuff()

    def _do_the_stuff(self):
        print('Do the stuff')

    def on_train_end(self, logs=None):
        # Optional safety net: only needed if training can stop early,
        # e.g. when EarlyStopping is used (see below).
        self._do_the_stuff()

If other callbacks such as tf.keras.callbacks.EarlyStopping are used, I would continue to use the approach with on_train_end: early stopping can end training before the configured number of epochs is reached, so the epoch == epochs - 1 condition may never fire and the callback would not be called after the actual last epoch.
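For illustration, here is a minimal sketch of how the callback might be registered alongside EarlyStopping. The toy model and data are assumptions made only to keep the example runnable; the epoch_freq argument refers to the constructor of the class above.

    import numpy as np
    from tensorflow import keras

    # Toy data and model, only to make the example self-contained.
    x = np.random.rand(100, 4)
    y = np.random.rand(100, 1)

    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer='adam', loss='mse')

    model.fit(
        x, y,
        epochs=50,
        verbose=0,
        callbacks=[
            MyCallBack(epoch_freq=10),
            # EarlyStopping may end training well before epoch 49; in that
            # case only on_train_end guarantees a final _do_the_stuff() call.
            keras.callbacks.EarlyStopping(monitor='loss', patience=3),
        ],
    )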
