簡體   English   中英

如何從可訓練的 tensorflow 變量轉換為不可訓練的 tensorflow 變量?

[英]How to convert from trainable tensorflow variables to not trainable tensorflow variable?

我需要恢復一個 DNN(VGG16Net),並使用遷移學習來構建另一個網絡。 所以在這里我需要將一些過濾器、偏差變量從可訓練的 tensorflow 變量轉換為不可訓練的變量(我使用的是原生 tensorflow 框架,而不是 keras 或任何更高級別的包)。

例如,為了從卷積層 4_1 獲取權重,我使用了conv4_3_filter=sess.graph.get_tensor_by_name('conv4_3/filter:0')但變量 ''conv4_3_filter'' 始終是可訓練的變量。 所以,在這里我試圖找到一種通用的方法來將任何 tensorflow 變量從可訓練轉換為不可訓練。 我該如何解決這個問題?

我認為不可能修改tf.Variable trainable屬性。 但是,有多種解決方法。

假設你有兩個變量:

import tensorflow as tf

v1 = tf.Variable(tf.random_normal([2, 2]), name='v1')
v2 = tf.Variable(tf.random_normal([2, 2]), name='v2')

當您使用tf.train.Optimizer類及其子類進行優化時,默認情況下它會從tf.GraphKeys.TRAINABLE_VARIABLES集合中獲取變量。 默認情況下,您使用trainable=True定義的每個變量都會添加到此集合中。 您可以做的是清除此集合並僅將您願意優化的那些變量附加到其中。 例如,如果我只想優化v1而不是v2

var_list = tf.trainable_variables()
print(var_list)
# [<tf.Variable 'v1:0' shape=(2, 2) dtype=float32_ref>,
#  <tf.Variable 'v2:0' shape=(2, 2) dtype=float32_ref>]

tf.get_default_graph().clear_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

cleared_var_list = tf.trainable_variables()
print(cleared_var_list)
# []

tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, var_list[0])

updated_var_list = tf.trainable_variables()
print(updated_var_list)
# [<tf.Variable 'v1:0' shape=(2, 2) dtype=float32_ref>]

另一種方法是使用優化器的var_list關鍵字參數並傳遞您想要在訓練期間(在執行train_op期間)更新的那些變量:

optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss, var_list=[v1])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM