
Is there a way to save a model at a specified epoch in tf.keras?

Using tf.keras.callbacks, I can only auto-save the best model by monitoring a single metric (typically validation accuracy), but sometimes I need to save based on a comparison of validation and training accuracy. How can I do this?

Does the tf.keras History object record the model's weights at every epoch? If so, how can I save my model from the history by specifying the epoch I want? That would be another possible solution.

This is the situation I'm running into: occasionally my validation accuracy is very high in an early epoch (purely by chance, I suppose) while my training accuracy is still far below it. That epoch ends up being the one that is auto-saved. It's a crappy model because of its poor training accuracy, but it's the one that got saved because of its high validation accuracy. If it had saved at a point where the training and validation accuracy meet, it would have been a pretty good model. So at every epoch I'd prefer to compare the training and validation accuracy, take the lower of the two, and choose my best model based on that. Any suggestions on how to do that?

You can implement a custom callback that applies your own saving condition at the end of every epoch, like this:

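import tensorflow as tf

class CustomModelCheckpoint(tf.keras.callbacks.Callback):
    def __init__(self, filepath='model.h5'):
        super().__init__()
        self.filepath, self.best = filepath, -float('inf')

    def on_epoch_end(self, epoch, logs=None):
        # logs is a dict keyed by metric name: 'accuracy'/'val_accuracy' for metrics=['accuracy'] in TF 2.x ('acc'/'val_acc' in older Keras)
        acc, val_acc = logs['accuracy'], logs['val_accuracy']
        print(f"epoch: {epoch}, train_acc: {acc}, valid_acc: {val_acc}")
        score = min(acc, val_acc)  # your custom condition: the lower of the two accuracies
        if score > self.best:  # save only when that lower accuracy improves
            self.best = score
            self.model.save(self.filepath, overwrite=True)

cbk = CustomModelCheckpoint()
model.fit(..., callbacks=[cbk])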

Check out the ModelCheckpoint callback at https://keras.io/callbacks/.

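You can save the model at every epoch and include the training/validation accuracy in the filename, or check the history object afterwards. The History object only records the per-epoch metrics, not the weights, so you would use it to pick the epoch and then load the file saved for that epoch.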
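A minimal sketch of the "check the history afterwards" route. It assumes the model was compiled with metrics=['accuracy'] (so History.history has 'accuracy' and 'val_accuracy' lists) and that every epoch was written to disk, for example with tf.keras.callbacks.ModelCheckpoint("model_{epoch:02d}.h5", save_best_only=False); the filename pattern and the helper best_epoch are hypothetical. It applies the criterion from the question: score each epoch by the lower of its two accuracies and keep the highest score.

```python
def best_epoch(history, train_key="accuracy", val_key="val_accuracy"):
    """Return (epoch, score) for the epoch whose min(train, val) accuracy is highest.

    `history` is a dict shaped like `History.history` returned by `model.fit`:
    one list of per-epoch values per metric name. The returned epoch is
    1-based to match ModelCheckpoint's {epoch} placeholder (an assumption
    about how the per-epoch files were named).
    """
    # Score each epoch by the lower of its training and validation accuracy.
    scores = [min(t, v) for t, v in zip(history[train_key], history[val_key])]
    # Index of the highest score; ties go to the earliest epoch.
    best_index = max(range(len(scores)), key=scores.__getitem__)
    return best_index + 1, scores[best_index]


if __name__ == "__main__":
    # Made-up numbers for illustration: epoch 2 has a high val accuracy by chance.
    history = {
        "accuracy":     [0.60, 0.65, 0.80, 0.86, 0.90],
        "val_accuracy": [0.70, 0.88, 0.82, 0.85, 0.84],
    }
    epoch, score = best_epoch(history)
    print(epoch, score)  # → 4 0.85
    # You could then reload that epoch's checkpoint, e.g.
    # tf.keras.models.load_model(f"model_{epoch:02d}.h5")
```

With these numbers, epoch 2 (val 0.88, train 0.65) would win under "monitor val_accuracy", but its score is only 0.65, so epoch 4 is chosen instead. The trade-off versus the custom callback is disk space: this approach keeps every epoch's file and decides afterwards.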
