简体   繁体   English

在 keras 回调文件名中包含 epoch 和 model 编号

[英]Include epoch and model number in keras callback filename

I'm training models with various hyperparameters iteratively in a for loop and I want to use a keras callback to save multiple models in a folder.我在 for 循环中迭代地训练具有各种超参数的模型,我想使用 keras 回调将多个模型保存在一个文件夹中。 I have been able to save the model number in each model but now I would also like to include variables such as epoch number (and to save the model every 5 epochs).我已经能够在每个 model 中保存 model 编号,但现在我还想包括诸如纪元编号之类的变量(并每 5 次保存 model)。 In the following code, I add 1 to counter each time my for loop runs to denote the model number.在下面的代码中,每次我的 for 循环运行时,我都会在计数器上加 1,以表示 model 编号。

filepath = root_path + "/saved_models/model_number_{}.h5".format(counter)
history = final_model.fit(x_train, y_train,
                                batch_size=batch_size,
                                epochs=epochs,
                                validation_data=(x_train, y_train),
                                shuffle=True,
                                callbacks= tf.keras.callbacks.ModelCheckpoint(filepath=filepath, monitor='val_accuracy', verbose=0, save_weights_only=True, mode='auto', save_freq='epoch'),
                                )

I can also make this filepath to save the epoch number and accuracy in the file name but I can't join it with my model.我也可以创建这个文件filepath来保存文件名中的纪元号和准确性,但我不能将它与我的 model 一起加入。 Is there a way to do so?有没有办法这样做?

filepath = s3_root_path + "/saved_models/weights.{epoch:02d}-{val_loss:.2f}.h5"

There are slight changes you need to do while saving the model into the same folder for every 5 epochs:在每 5 个 epoch 将 model 保存到同一文件夹中时,您需要做一些细微的更改:

  1. Root path name should be the same where your are saving all the models (See the name difference root_path , s3_root_path in above code)根路径名称应该与您保存所有模型的位置相同(请参阅上面代码中的名称差异root_paths3_root_path
  2. File name format while saving the model should be correct保存 model 时的文件名格式应该是正确的

Please check the below fixed code:请检查以下固定代码:

#Created root 'SAVING/saved_models' folder for saving the entire checkpoints.
!mkdir -p SAVING/saved_models

#'SAVING' folder directory
!ls SAVING/   # Output :saved_models

Saving the model every 5 epochs by changing save_freq=5*batch_size通过更改save_freq=5*batch_size每 5 个 epoch 保存 model

root_path="SAVING/"
checkpoint_path = root_path + "saved_models/weights.{epoch:02d}.h5"  #-{val_loss:.2f}
checkpoint_dir = os.path.dirname(checkpoint_path)

batch_size = 32

cp_callback= tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, 
                                                monitor='val_sparse_categorical_accuracy', 
                                                verbose=1, 
                                                save_weights_only=True,  
                                                save_freq=5*batch_size)


history = model.fit(train_images, train_labels, 
                                batch_size=batch_size,
                                epochs=50,
                                validation_data=(test_images, test_labels),
                                shuffle=True,
                                callbacks=[cp_callback],verbose=0)

Output: Output:

Epoch 5: saving model to SAVING/saved_models/weights.05.h5

Epoch 10: saving model to SAVING/saved_models/weights.10.h5

Epoch 15: saving model to SAVING/saved_models/weights.15.h5

Epoch 20: saving model to SAVING/saved_models/weights.20.h5

Epoch 25: saving model to SAVING/saved_models/weights.25.h5

Epoch 30: saving model to SAVING/saved_models/weights.30.h5

Epoch 35: saving model to SAVING/saved_models/weights.35.h5

Epoch 40: saving model to SAVING/saved_models/weights.40.h5

Epoch 45: saving model to SAVING/saved_models/weights.45.h5

Epoch 50: saving model to SAVING/saved_models/weights.50.h5

To check saved checkpoints in the root folder directory:要检查根文件夹目录中保存的检查点:

!ls SAVING/saved_models

Output: Output:

weights.05.h5  weights.15.h5  weights.25.h5  weights.35.h5  weights.45.h5
weights.10.h5  weights.20.h5  weights.30.h5  weights.40.h5  weights.50.h5

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM