繁体   English   中英

如何从可训练的 tensorflow 变量转换为不可训练的 tensorflow 变量?

[英]How to convert from trainable tensorflow variables to not trainable tensorflow variable?

我需要恢复一个 DNN(VGG16Net),并使用迁移学习来构建另一个网络。 所以在这里我需要将一些过滤器、偏差变量从可训练的 tensorflow 变量转换为不可训练的变量(我使用的是原生 tensorflow 框架,而不是 keras 或任何更高级别的包)。

例如,为了从卷积层 4_1 获取权重,我使用了conv4_3_filter=sess.graph.get_tensor_by_name('conv4_3/filter:0')但变量 ''conv4_3_filter'' 始终是可训练的变量。 所以,在这里我试图找到一种通用的方法来将任何 tensorflow 变量从可训练转换为不可训练。 我该如何解决这个问题?

我认为不可能修改tf.Variable trainable属性。 但是,有多种解决方法。

假设你有两个变量:

import tensorflow as tf

v1 = tf.Variable(tf.random_normal([2, 2]), name='v1')
v2 = tf.Variable(tf.random_normal([2, 2]), name='v2')

当您使用tf.train.Optimizer类及其子类进行优化时,默认情况下它会从tf.GraphKeys.TRAINABLE_VARIABLES集合中获取变量。 默认情况下,您使用trainable=True定义的每个变量都会添加到此集合中。 您可以做的是清除此集合并仅将您愿意优化的那些变量附加到其中。 例如,如果我只想优化v1而不是v2

var_list = tf.trainable_variables()
print(var_list)
# [<tf.Variable 'v1:0' shape=(2, 2) dtype=float32_ref>,
#  <tf.Variable 'v2:0' shape=(2, 2) dtype=float32_ref>]

tf.get_default_graph().clear_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

cleared_var_list = tf.trainable_variables()
print(cleared_var_list)
# []

tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, var_list[0])

updated_var_list = tf.trainable_variables()
print(updated_var_list)
# [<tf.Variable 'v1:0' shape=(2, 2) dtype=float32_ref>]

另一种方法是使用优化器的var_list关键字参数并传递您想要在训练期间(在执行train_op期间)更新的那些变量:

optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss, var_list=[v1])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM