簡體   English   中英

如何從 Tensorflow 檢查點將權重加載到 Keras 模型

[英]How to load_weights to a Keras model from a Tensorflow checkpoint

我有一些 Python 代碼可以使用 Tensorflow 的 TFRecords 和 Dataset API 來訓練網絡。 我已經使用 tf.Keras.layers 構建了網絡,這可以說是最簡單和最快的方法。 方便的函數 model_to_estimator()

modelTF = tf.keras.estimator.model_to_estimator(
    keras_model=model,
    custom_objects=None,
    config=run_config,
    model_dir=checkPointDirectory
)

將 Keras 模型轉換為估計器,這使我們能夠很好地利用數據集 API,並在訓練期間和訓練完成后自動將檢查點保存到 checkPointDirectory。 estimator API 提供了一些非常寶貴的功能,例如自動將工作負載分配到多個 GPU,例如

distribution = tf.contrib.distribute.MirroredStrategy()
run_config = tf.estimator.RunConfig(train_distribute=distribution)

現在對於大模型和大量數據,使用某種形式的保存模型在訓練后執行預測通常很有用。 似乎從 Tensorflow 1.10(參見https://github.com/tensorflow/tensorflow/issues/19295 )開始, tf.keras.model 對象支持來自 Tensorflow 檢查點的 load_weights() 。 Tensorflow 文檔中簡要提到了這一點,但 Keras 文檔中沒有提到這一點,我找不到任何人展示這方面的示例。 在一些新的 .py 中再次定義模型層后,我嘗試過

checkPointPath = os.path.join('.', 'tfCheckPoints', 'keras_model.ckpt.index')
model.load_weights(filepath=checkPointPath, by_name=False)

但這給出了一個 NotImplementedError:

Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.

2018-10-01 14:24:49.912087:
Traceback (most recent call last):
  File "C:/Users/User/PycharmProjects/python/mercury.classifier reductions/V3.2/wikiTestv3.2/modelEvaluation3.2.py", line 141, in <module>
    model.load_weights(filepath=checkPointPath, by_name=False)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1526, in load_weights
    checkpointable_utils.streaming_restore(status=status, session=session)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\training\checkpointable\util.py", line 880, in streaming_restore
    "Streaming restore not supported from name-based checkpoints. File a "
NotImplementedError: Streaming restore not supported from name-based checkpoints. File a feature request if this limitation bothers you.

我想按照警告的建議去做,而是使用“基於對象的保護程序”,但我還沒有找到通過傳遞給 estimator.train() 的 RunConfig 來做到這一點的方法。

那么有沒有更好的方法將保存的權重返回到估計器中以用於預測? github 線程似乎表明這已經實現(盡管基於錯誤,可能與我上面嘗試的方式不同)。 有沒有人在 TF 檢查點上成功使用過 load_weights() ? 我一直無法找到有關如何完成此操作的任何教程/示例,因此不勝感激。

我不確定,但也許您可以將keras_model.ckpt.index更改為keras_model.ckpt進行測試。

您可以創建一個單獨的圖表,正常加載您的檢查點,然后將權重轉移到您的 Keras 模型:

_graph = tf.Graph()
_sess = tf.Session(graph=_graph)

tf.saved_model.load(_sess, ['serve'], '../tf1_save/')

_weights_all, _bias_all = [], []
with _graph.as_default():
  for idx, t_var in enumerate(tf.trainable_variables()):
    # substitue variable_scope with your scope
    if 'variable_scope/' not in t_var.name: break
    
    print(t_var.name)
    val = _sess.run(t_var)
    _weights_all.append(val) if idx % 2 == 0 else _bias_all.append(val)

for layer, (weight, bias) in enumerate(zip(_weights_all, _bias_all)):
  self.model.layers[layer].set_weights([np.array(weight), np.array(bias)])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM