[英]Load and continue running from the epoch at which they stopped
我正在训练一个巨大的模型。 不幸的是,运行时环境中断了大约一半,我必须重新启动模型。我在每个 epoch 后保存模型。
但是我现在的问题是,例如,我已经训练了 10 个 epcoh 中的 5 个。 我如何加载它并表明我在第 5 个时代并且他必须在那里继续,所以只需要经历 5 个时代? 我知道我可以加载模型,但我怎么能说我在第 5 个时期,现在你只需要经历 5 个时期,因为我总共想要 10 个时期。
cp_callback = [tf.keras.callbacks.ModelCheckpoint(
filepath='/saved/model.h5',
verbose=1,
save_weights_only=True,
save_freq= 'epoch'),
tf.keras.callbacks.EarlyStopping(monitor='loss', patience=2)]
您可以将纪元编号保存在单独的文件(pickle 或 json 文件)中。
import json
train_parameters = {'iter': iteration, 'batch_size': batch_size'}
# saving
json.dump(trainParameters, open(output_path+"trainParameters.txt",'w'))
# loading
trainParameters = json.load(open(path_to_saved_model+"trainParameters.txt"))
input = tf.random.uniform([8, 24], 0, 100, dtype=tf.int32)
model.compile(optimizer=optimizer, loss=training_loss, metrics=evaluation_accuracy)
hist = model.fit((input, input), input, epochs=1,
steps_per_epoch=1, verbose=0)
model.load_weights(path_to_saved_model+'saved.h5')
但是如果您需要保存学习率步骤 - 保存优化器状态。 状态包含迭代次数(通过的批次数)。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.