![](/img/trans.png)
[英]Tensorflow — Cannot call `tf.keras.Model.add_metric` when `tf.distribute.MirroredStrategy` is used
[英]Can't load tensorflow keras checkpoint when using tf.distribute.MirroredStrategy()
我正在嘗試從使用 ModelCheckpoint 回調創建的檢查點加載 tf.keras (v1.15.0) 模型,通過刪除多個層並添加新層來修改它,然后繼續在新任務上對其進行訓練。 我正在使用 tf.distribute.MirroredStrategy() 用 2 gpus 進行分布式訓練。
strategy = tensorflow.distribute.MirroredStrategy()
with strategy.scope():
# Load pretrained model from checkpoint
model = get_model()
model.load_weights('file_name.hdf5')
# Chop off some layers, add new layers
model = modify_pretrained_model(model)
model.compile(optimizer=opt, loss=loss)
模型加載良好並編譯,我可以運行 model.summary(),但是當我調用 model.fit() 或 model.predict() 時,我的 python 堆棧中出現以下錯誤:
(0) Failed precondition: Error while reading resource variable compression0_conv0_batchnorm/moving_variance from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/compression0_conv0_batchnorm/moving_variance/N10tensorflow3VarE does not exist.
[[{{node time_distributed_1/model_1/compression0_conv0_batchnorm/FusedBatchNormV3/ReadVariableOp_1}}]]
[[dense_1_1/Sigmoid/_225]]
(1) Failed precondition: Error while reading resource variable compression0_conv0_batchnorm/moving_variance from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/compression0_conv0_batchnorm/moving_variance/N10tensorflow3VarE does not exist.
[[{{node time_distributed_1/model_1/compression0_conv0_batchnorm/FusedBatchNormV3/ReadVariableOp_1}}]]
0 successful operations.
1 derived errors ignored
這個問題似乎解決了這個確切的問題,但沒有使用 tf.distribute 繼續訓練。
當我在分發范圍之外實例化一個會話,並在分發范圍內設置對它的引用時,代碼會因相同的錯誤而崩潰。
tf_config = some_custom_config
sess = tf.Session(config=tf_config)
graph = tf.get_default_graph()
strategy = tensorflow.distribute.MirroredStrategy()
with strategy.scope():
set_session(sess)
# Load pretrained model from checkpoint
model = get_model()
model.load_weights('file_name.hdf5')
# Chop off some layers, add new layers
model = modify_pretrained_model(model)
model.compile(optimizer=opt, loss=loss)
我花了 2-3 天的時間試圖弄清楚這一點。 唯一真正有效的是升級到 tf 2.0.0。 然后一切都像魔術一樣運作。 或者作為最后的手段,我能夠訓練第一個模型,添加和刪除額外的層,重新編譯,並使用相同的分發策略在同一個 python 執行中繼續訓練,但永遠無法使用分發策略重新加載 tf.keras ModelCheckpoint在 tf 1.15.0 中。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.