簡體   English   中英

從它們停止的紀元加載並繼續運行

[英]Load and continue running from the epoch at which they stopped

我正在訓練一個巨大的模型。 不幸的是,運行時環境中斷了大約一半,我必須重新啟動模型。我在每個 epoch 后保存模型。

但是我現在的問題是,例如,我已經訓練了 10 個 epcoh 中的 5 個。 我如何加載它並表明我在第 5 個時代並且他必須在那里繼續,所以只需要經歷 5 個時代? 我知道我可以加載模型,但我怎么能說我在第 5 個時期,現在你只需要經歷 5 個時期,因為我總共想要 10 個時期。

cp_callback = [tf.keras.callbacks.ModelCheckpoint(
    filepath='/saved/model.h5', 
    verbose=1, 
    save_weights_only=True,
    save_freq= 'epoch'),
    tf.keras.callbacks.EarlyStopping(monitor='loss', patience=2)]

您可以將紀元編號保存在單獨的文件(pickle 或 json 文件)中。

import json

train_parameters = {'iter': iteration, 'batch_size': batch_size'}

# saving

json.dump(trainParameters, open(output_path+"trainParameters.txt",'w'))

# loading

trainParameters = json.load(open(path_to_saved_model+"trainParameters.txt"))
input = tf.random.uniform([8, 24], 0, 100, dtype=tf.int32)
model.compile(optimizer=optimizer, loss=training_loss, metrics=evaluation_accuracy)
hist = model.fit((input, input), input, epochs=1,
                        steps_per_epoch=1, verbose=0)
model.load_weights(path_to_saved_model+'saved.h5')

但是如果您需要保存學習率步驟 - 保存優化器狀態。 狀態包含迭代次數(通過的批次數)。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM