[英]How to convert from trainable tensorflow variables to not trainable tensorflow variable?
我需要恢復一個 DNN(VGG16Net),並使用遷移學習來構建另一個網絡。 所以在這里我需要將一些過濾器、偏差變量從可訓練的 tensorflow 變量轉換為不可訓練的變量(我使用的是原生 tensorflow 框架,而不是 keras 或任何更高級別的包)。
例如,為了從卷積層 4_1 獲取權重,我使用了conv4_3_filter=sess.graph.get_tensor_by_name('conv4_3/filter:0')
但變量 ''conv4_3_filter'' 始終是可訓練的變量。 所以,在這里我試圖找到一種通用的方法來將任何 tensorflow 變量從可訓練轉換為不可訓練。 我該如何解決這個問題?
我認為不可能修改tf.Variable
trainable
屬性。 但是,有多種解決方法。
假設你有兩個變量:
import tensorflow as tf
v1 = tf.Variable(tf.random_normal([2, 2]), name='v1')
v2 = tf.Variable(tf.random_normal([2, 2]), name='v2')
當您使用tf.train.Optimizer
類及其子類進行優化時,默認情況下它會從tf.GraphKeys.TRAINABLE_VARIABLES
集合中獲取變量。 默認情況下,您使用trainable=True
定義的每個變量都會添加到此集合中。 您可以做的是清除此集合並僅將您願意優化的那些變量附加到其中。 例如,如果我只想優化v1
而不是v2
:
var_list = tf.trainable_variables()
print(var_list)
# [<tf.Variable 'v1:0' shape=(2, 2) dtype=float32_ref>,
# <tf.Variable 'v2:0' shape=(2, 2) dtype=float32_ref>]
tf.get_default_graph().clear_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
cleared_var_list = tf.trainable_variables()
print(cleared_var_list)
# []
tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, var_list[0])
updated_var_list = tf.trainable_variables()
print(updated_var_list)
# [<tf.Variable 'v1:0' shape=(2, 2) dtype=float32_ref>]
另一種方法是使用優化器的var_list
關鍵字參數並傳遞您想要在訓練期間(在執行train_op
期間)更新的那些變量:
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss, var_list=[v1])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.