简体   繁体   English

如何在Keras中保存训练期间保存的模型列表(仅保存最佳)?

[英]How to keep a list of models saved during training (save best only) in Keras?

I am using Modelcheckpoint feature to save my models based upon "save best only" criteria. 我正在使用Modelcheckpoint功能根据“仅保存最佳”条件保存模型。

file_name = str(datetime.datetime.now()).split(' ')[0] + f'{model_name}'+ '_{epoch:02d}.hdf5'

checkpoint_main = ModelCheckpoint(filename, monitor='val_acc', verbose=2,
                                save_best_only=True, save_weights_only=False,
                                mode='auto', period=1)

Since I am using "save best only" it will save only certain epochs. 由于我使用的是“仅保存最佳”,因此只会保存某些时期。 I want to collect the paths of actually saved models and save them to a list that I can access at the end of training. 我想收集实际保存的模型的路径,并将它们保存到训练结束时可以访问的列表。 This list will be piped to other operations. 该列表将通过管道传输到其他操作。

I tried looking at the source code, but I didn't see any examples of "train_end" that returns a list, so I am not really sure how to return something at the end of training. 我尝试查看源代码,但是没有看到返回列表的“ train_end”示例,因此我不确定在培训结束时如何返回某些内容。

https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L360 https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L360

If you wnat to store all the paths to saved models for each epoch, you could use a Callback , because Callback is just a python object and can collect data. 如果您要为每个时期存储保存的模型的所有路径,则可以使用Callback ,因为Callback只是一个python对象,可以收集数据。

For exmaple, it can store a paths to the models in a list: 例如,它可以在列表中存储模型的路径:

import datetime

class SaveEveryEpoch(Callback):
    def __init__(self, model_name, *args, **kwargs):
        self.model_checkpoint_paths = []
        self.model_name = model_name
        super().__init__(*args, **kwargs)

    def on_epoch_end(self, epoch, logs):
        # I suppose here it's a Functional model
        print(logs['acc'])
        path_to_checkpoint = (
            str(datetime.datetime.now()).split(' ')[0] 
            + f'_{self.model_name}'
            + f'_{epoch:02d}.hdf5'
        )
        self.model.save(path_to_checkpoint)
        self.model_checkpoint_paths.append(path_to_checkpoint)
  • __init__ initializes an empty list and stores a model basic name __init__初始化一个空列表并存储模型基本名称
  • on_epoch_end saves the model at the end of each epoch; on_epoch_end在每个时期结束时保存模型; also it appends a model path to list of model paths 还将模型路径追加到模型路径列表中

Example of usage 使用例

from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
import numpy as np

(X_tr, y_tr), (X_te, y_te) = mnist.load_data()
X_tr = (X_tr / 255.).reshape((60000, 784))
X_te = (X_te / 255.).reshape((10000, 784))


def binarize_labels(y):
    y_bin = np.zeros((len(y), len(np.unique(y)))) 
    y_bin[range(len(y)), y] = 1
    return y_bin

y_train_bin, y_test_bin = binarize_labels(y_tr), binarize_labels(y_te)

model = Sequential()
model.add(InputLayer((784,)))
model.add(Dense(784, activation='relu'))
model.add(Dense(256, activation='relu'))
model.add(Dense(10, activation='softmax'))

model = Model(inp, out)
model.compile(loss='categorical_crossentropy', optimizer='adam')

Run the model with checkpoints 使用检查点运行模型

checkpoints = SaveEveryEpoch('mnist_model')
history = model.fit(X_tr, y_train_bin, callbacks=[checkpoints], epochs=3)
# ... training progress ...
checkpoints.model_checkpoint_paths
Out:
['2019-02-06_mnist_model_00.hdf5',
 '2019-02-06_mnist_model_01.hdf5',
 '2019-02-06_mnist_model_02.hdf5']

ls output: ls输出:

2019-02-06_mnist_model_00.hdf5
2019-02-06_mnist_model_01.hdf5
2019-02-06_mnist_model_02.hdf5

Variations 变化

Modify on_epoch_end to create some collection which can be ordered by a loss , for example ( logs argument contains a dictionary with loss and a metric named acc if some metric was provided). 修改on_epoch_end以创建一些可以按loss排序的集合,例如( logs参数包含具有loss的字典和一个名为acc的度量(如果提供了某些度量)。 So, you can select a model with minimal loss/metric value later: 因此,您以后可以选择损耗/度量值最小的模型:

class SaveEveryEpoch(Callback):
    def __init__(self, model_name, *args, **kwargs):
        self.model_checkpoints_with_loss = []
        self.model_name = model_name
        super().__init__(*args, **kwargs)

    def on_epoch_end(self, epoch, logs):
        # I suppose here it's a Functional model
        print(logs['acc'])
        path_to_checkpoint = (
            str(datetime.datetime.now()).split(' ')[0] 
            + f'_{self.model_name}'
            + f'_{epoch:02d}.hdf5'
        )
        self.model.save(path_to_checkpoint)
        self.model_checkpoints_with_loss.append((logs['loss'], path_to_checkpoint))

Also, you can overload a default callback, for example, ModelCheckpoint to save all the paths, not only the best model, but I guess it's unnecessary in this case. 另外,您可以重载默认回调,例如ModelCheckpoint,以保存所有路径,不仅是最佳模型,而且我想在这种情况下没有必要。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM