简体   繁体   English

如何从 Tensorflow 检查点将权重加载到 Keras 模型

[英]How to load_weights to a Keras model from a Tensorflow checkpoint

I have some python code to train a network using Tensorflow's TFRecords and Dataset APIs.我有一些 Python 代码可以使用 Tensorflow 的 TFRecords 和 Dataset API 来训练网络。 I have built the network using tf.Keras.layers, this being arguably the easiest and fastest way.我已经使用 tf.Keras.layers 构建了网络,这可以说是最简单和最快的方法。 The handy function model_to_estimator()方便的函数 model_to_estimator()

modelTF = tf.keras.estimator.model_to_estimator(
    keras_model=model,
    custom_objects=None,
    config=run_config,
    model_dir=checkPointDirectory
)

converts a Keras model to an estimator, which allows us to take advantage of the Dataset API nicely, and automatically save checkpoints to checkPointDirectory during training, and upon training completion.将 Keras 模型转换为估计器,这使我们能够很好地利用数据集 API,并在训练期间和训练完成后自动将检查点保存到 checkPointDirectory。 The estimator API presents some invaluable features, such as automatically distributing the workload over multiple GPUs, with, eg estimator API 提供了一些非常宝贵的功能,例如自动将工作负载分配到多个 GPU,例如

distribution = tf.contrib.distribute.MirroredStrategy()
run_config = tf.estimator.RunConfig(train_distribute=distribution)

Now for big models and lots of data, it is often useful to execute predictions after training using some form of saved model.现在对于大模型和大量数据,使用某种形式的保存模型在训练后执行预测通常很有用。 It seems that as of Tensorflow 1.10 (see https://github.com/tensorflow/tensorflow/issues/19295 ), a tf.keras.model object supports load_weights() from a Tensorflow checkpoint.似乎从 Tensorflow 1.10(参见https://github.com/tensorflow/tensorflow/issues/19295 )开始, tf.keras.model 对象支持来自 Tensorflow 检查点的 load_weights() 。 This is mentioned briefly in the Tensorflow docs, but not the Keras docs, and I can't find anyone showing an example of this. Tensorflow 文档中简要提到了这一点,但 Keras 文档中没有提到这一点,我找不到任何人展示这方面的示例。 After defining the model layers again in some new .py, I have tried在一些新的 .py 中再次定义模型层后,我尝试过

checkPointPath = os.path.join('.', 'tfCheckPoints', 'keras_model.ckpt.index')
model.load_weights(filepath=checkPointPath, by_name=False)

but this gives a NotImplementedError:但这给出了一个 NotImplementedError:

Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.

2018-10-01 14:24:49.912087:
Traceback (most recent call last):
  File "C:/Users/User/PycharmProjects/python/mercury.classifier reductions/V3.2/wikiTestv3.2/modelEvaluation3.2.py", line 141, in <module>
    model.load_weights(filepath=checkPointPath, by_name=False)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1526, in load_weights
    checkpointable_utils.streaming_restore(status=status, session=session)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\training\checkpointable\util.py", line 880, in streaming_restore
    "Streaming restore not supported from name-based checkpoints. File a "
NotImplementedError: Streaming restore not supported from name-based checkpoints. File a feature request if this limitation bothers you.

I would like to do as suggested by the Warning and use the 'object-based saver' instead, but I haven't found a way to do this via a RunConfig passed to estimator.train().我想按照警告的建议去做,而是使用“基于对象的保护程序”,但我还没有找到通过传递给 estimator.train() 的 RunConfig 来做到这一点的方法。

So is there a better way to get the saved weights back into an estimator for use in prediction?那么有没有更好的方法将保存的权重返回到估计器中以用于预测? The github thread seems to suggest that this is already implemented (though based on the error, probably in a different way than I am attempting above). github 线程似乎表明这已经实现(尽管基于错误,可能与我上面尝试的方式不同)。 Has anyone successfully used load_weights() on a TF checkpoint?有没有人在 TF 检查点上成功使用过 load_weights() ? I haven't been able to find any tutorials/examples on how this can be done, so any help is appreciated.我一直无法找到有关如何完成此操作的任何教程/示例,因此不胜感激。

我不确定,但也许您可以将keras_model.ckpt.index更改为keras_model.ckpt进行测试。

You can create a separate graph, load your checkpoint normally and then transfer weights to your Keras model:您可以创建一个单独的图表,正常加载您的检查点,然后将权重转移到您的 Keras 模型:

_graph = tf.Graph()
_sess = tf.Session(graph=_graph)

tf.saved_model.load(_sess, ['serve'], '../tf1_save/')

_weights_all, _bias_all = [], []
with _graph.as_default():
  for idx, t_var in enumerate(tf.trainable_variables()):
    # substitue variable_scope with your scope
    if 'variable_scope/' not in t_var.name: break
    
    print(t_var.name)
    val = _sess.run(t_var)
    _weights_all.append(val) if idx % 2 == 0 else _bias_all.append(val)

for layer, (weight, bias) in enumerate(zip(_weights_all, _bias_all)):
  self.model.layers[layer].set_weights([np.array(weight), np.array(bias)])

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM