Can't load tensorflow keras checkpoint when using tf.distribute.MirroredStrategy()
I am trying to load a tf.keras (v1.15.0) model from a checkpoint created with the ModelCheckpoint callback, modify it by removing several layers and adding new ones, and then continue training it on a new task. I am using tf.distribute.MirroredStrategy() for distributed training on 2 GPUs.
```python
strategy = tensorflow.distribute.MirroredStrategy()
with strategy.scope():
    # Load pretrained model from checkpoint
    model = get_model()
    model.load_weights('file_name.hdf5')
    # Chop off some layers, add new layers
    model = modify_pretrained_model(model)
    model.compile(optimizer=opt, loss=loss)
```
The model loads and compiles fine, and I can run model.summary(), but when I call model.fit() or model.predict(), the following error appears in my Python stack trace:
```
(0) Failed precondition: Error while reading resource variable compression0_conv0_batchnorm/moving_variance from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/compression0_conv0_batchnorm/moving_variance/N10tensorflow3VarE does not exist.
     [[{{node time_distributed_1/model_1/compression0_conv0_batchnorm/FusedBatchNormV3/ReadVariableOp_1}}]]
     [[dense_1_1/Sigmoid/_225]]
(1) Failed precondition: Error while reading resource variable compression0_conv0_batchnorm/moving_variance from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/compression0_conv0_batchnorm/moving_variance/N10tensorflow3VarE does not exist.
     [[{{node time_distributed_1/model_1/compression0_conv0_batchnorm/FusedBatchNormV3/ReadVariableOp_1}}]]
0 successful operations.
1 derived errors ignored
```
This question seems to address this exact problem, but without continuing training using tf.distribute.
When I instantiate a session outside the distribution scope and set a reference to it inside the scope, the code crashes with the same error:
```python
tf_config = some_custom_config
sess = tf.Session(config=tf_config)
graph = tf.get_default_graph()

strategy = tensorflow.distribute.MirroredStrategy()
with strategy.scope():
    set_session(sess)
    # Load pretrained model from checkpoint
    model = get_model()
    model.load_weights('file_name.hdf5')
    # Chop off some layers, add new layers
    model = modify_pretrained_model(model)
    model.compile(optimizer=opt, loss=loss)
```
I spent 2-3 days trying to figure this out. The only thing that actually worked was upgrading to tf 2.0.0, after which everything worked like magic. As a last resort, I was able to train the first model, add and remove extra layers, recompile, and continue training within the same Python execution using the same distribution strategy, but I was never able to reload a tf.keras ModelCheckpoint under a distribution strategy in tf 1.15.0.