[英]How to convert from trainable tensorflow variables to not trainable tensorflow variable?
[英]Missing (trainable) variables in Tensorflow 2
我在更大的代碼中遇到了這個問題。 我在下面的測試代碼中重現了它。 tensorflow 2.1 未完全列出可訓練變量。
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
class FooLayer(tf.keras.layers.Layer):
def __init__(self, siz):
super(FooLayer, self).__init__()
self.siz = siz
self.buildFoo(siz)
def call(self, in_data):
Foo0 = tf.multiply(in_data,self.FooTns0)
FooList = []
FooList.append(Foo0)
for it in range(1,self.siz+1):
tmp = tf.multiply(FooList[it-1],self.FooTns[it-1])
FooList.append(tmp)
return FooList[self.siz]
def buildFoo(self,siz):
self.FooTns0 = tf.Variable(1.0, name="TNS0")
self.FooTns = []
for it in range(0,self.siz):
self.FooTns.append(tf.Variable(np.float32(it),
name="TNS"+str(it+1)))
self.add_weight("TNS"+str(it+1)) # Added after the first suggestion
class FooModel(tf.keras.Model):
def __init__(self, siz):
super(FooModel, self).__init__()
self.flayer = FooLayer(siz)
def call(self, in_data):
return self.flayer(in_data)
model = FooModel(5)
for v in model.trainable_variables:
print(v.name)
for v in model.variables:
print(v.name)
x = np.arange(1.0,2.0,1.0)
x = x.astype(np.float32)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
with tf.GradientTape() as tape:
y = model(x)
grads = tape.gradient(y, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
原廠output目前只有:
TNS0:0
TNS0:0
雖然預期的 output 列出了所有 6 個張量,“self.FooTns0”和“self.FooTns”。
第一個建議
在@Wathek LOUED 的第一個建議之后,我添加了self.add_weight("TNS"+str(it+1))
行,並且 output 確實包括所有其他 TNS。 但是,梯度仍然沒有找到它們並給出錯誤消息,
WARNING:tensorflow:Gradients do not exist for variables ['TNS1:0', 'TNS2:0', 'TNS3:0', 'TNS4:0', 'TNS5:0'] when minimizing the loss.
使用add_weight
方法代替 class Layer
的一部分怎么樣?
for it in range(0,self.siz):
self.add_weight("TNS"+str(it+1))
事實證明這是 TF2 中的一個錯誤,由 TF 組確認( https://github.com/tensorflow/tensorflow/issues/38211 )
有一個臨時的解決方法,它顯式地將buildFoo
function 中的張量列表返回到__init__
。 例如見鏈接。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.