
How to resume from a checkpoint when using Horovod with tf.keras?

Note: I'm using TF 2.1.0 and the tf.keras API. I've experienced the below issue with all Horovod versions between 0.18 and 0.19.2.

Are we supposed to call hvd.load_model() on all ranks when resuming from a tf.keras h5 checkpoint, or are we only supposed to call it on rank 0 and let BroadcastGlobalVariablesCallback share those weights with the other workers? Is the first approach incorrect or invalid, in the sense that it will corrupt training or produce different results than the second?

I'm currently training a ResNet-based model with several BatchNorm layers. If I load the checkpoint only on rank 0 (and build/compile the model on the other ranks), I hit a stalled-tensor issue ( https://github.com/horovod/horovod/issues/1271 ) and BatchNorm hangs in Horovod. If I instead call hvd.load_model on all ranks when resuming, training starts normally but appears to diverge almost immediately. So I'm confused: can loading the checkpoint model on all ranks (with hvd.load_model) somehow cause training to diverge? At the same time, loading it only on rank 0 isn't an option because of https://github.com/horovod/horovod/issues/1271 . Has anyone been able to successfully call hvd.load_model only on rank 0 when using BatchNorm tf.keras layers? Can someone please provide some tips here?
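
For concreteness, here is a minimal sketch of the two approaches I'm describing. The checkpoint path, the build_model() helper, and the dummy data are placeholders, not my actual training code:

# Minimal sketch contrasting the two resume strategies; 'checkpoint.h5' and
# build_model() are hypothetical placeholders for the real checkpoint and model.
import horovod.tensorflow.keras as hvd
import numpy as np
import tensorflow as tf

hvd.init()

def build_model():
    # Placeholder for the real ResNet/BatchNorm model, compiled with a
    # DistributedOptimizer-wrapped Adam.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(1),
    ])
    opt = hvd.DistributedOptimizer(tf.keras.optimizers.Adam(1e-3 * hvd.size()))
    model.compile(optimizer=opt, loss='mse')
    return model

RESUME_ON_ALL_RANKS = True  # flip to compare the two approaches

if RESUME_ON_ALL_RANKS:
    # Approach 1: every rank deserializes the checkpoint itself.
    model = hvd.load_model('checkpoint.h5')
else:
    # Approach 2: only rank 0 loads; the other ranks build a fresh model and
    # receive rank 0's weights via BroadcastGlobalVariablesCallback.
    model = hvd.load_model('checkpoint.h5') if hvd.rank() == 0 else build_model()

callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)]
x, y = np.random.rand(256, 8), np.random.rand(256, 1)  # dummy data for the sketch
model.fit(x, y, batch_size=32, epochs=1, callbacks=callbacks,
          verbose=1 if hvd.rank() == 0 else 0)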

Thanks!

According to https://github.com/horovod/horovod/issues/120 , this is the solution:

You should also be able to specify the optimizer via custom objects:
import horovod.tensorflow.keras as hvd
from tensorflow import keras  # with TF 2.x, this is tf.keras
# Re-wrap the deserialized Adam in Horovod's DistributedOptimizer at load time
model = keras.models.load_model('file.h5', custom_objects={
    'Adam': lambda **kwargs: hvd.DistributedOptimizer(keras.optimizers.Adam(**kwargs))
})
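
As I understand it, this works because the h5 checkpoint stores a plain Adam optimizer; mapping 'Adam' through custom_objects re-wraps it in hvd.DistributedOptimizer during deserialization, which appears to be essentially what hvd.load_model does for the built-in Keras optimizers.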
