簡體   English   中英

Tensorflow 2 中缺少(可訓練的)變量

[英]Missing (trainable) variables in Tensorflow 2

我在更大的代碼中遇到了這個問題。 我在下面的測試代碼中重現了它。 tensorflow 2.1 未完全列出可訓練變量。

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np

class FooLayer(tf.keras.layers.Layer):
    def __init__(self, siz):
        super(FooLayer, self).__init__()
        self.siz = siz
        self.buildFoo(siz)

    def call(self, in_data):
        Foo0 = tf.multiply(in_data,self.FooTns0)
        FooList = []
        FooList.append(Foo0)
        for it in range(1,self.siz+1):
            tmp = tf.multiply(FooList[it-1],self.FooTns[it-1])
            FooList.append(tmp)
        return FooList[self.siz]

    def buildFoo(self,siz):
        self.FooTns0 = tf.Variable(1.0, name="TNS0")
        self.FooTns = []
        for it in range(0,self.siz):
            self.FooTns.append(tf.Variable(np.float32(it),
                name="TNS"+str(it+1)))
            self.add_weight("TNS"+str(it+1)) # Added after the first suggestion

class FooModel(tf.keras.Model):
    def __init__(self, siz):
        super(FooModel, self).__init__()
        self.flayer = FooLayer(siz)

    def call(self, in_data):
        return self.flayer(in_data)

model = FooModel(5)

for v in model.trainable_variables:
    print(v.name)

for v in model.variables:
    print(v.name)

x = np.arange(1.0,2.0,1.0)
x = x.astype(np.float32)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)

with tf.GradientTape() as tape:
    y = model(x)
grads = tape.gradient(y, model.trainable_variables)

optimizer.apply_gradients(zip(grads, model.trainable_variables))

原廠output目前只有:

TNS0:0
TNS0:0

雖然預期的 output 列出了所有 6 個張量,“self.FooTns0”和“self.FooTns”。

第一個建議

在@Wathek LOUED 的第一個建議之后,我添加了self.add_weight("TNS"+str(it+1))行,並且 output 確實包括所有其他 TNS。 但是,梯度仍然沒有找到它們並給出錯誤消息,

WARNING:tensorflow:Gradients do not exist for variables ['TNS1:0', 'TNS2:0', 'TNS3:0', 'TNS4:0', 'TNS5:0'] when minimizing the loss.

使用add_weight方法代替 class Layer的一部分怎么樣?

for it in range(0,self.siz):
  self.add_weight("TNS"+str(it+1))

事實證明這是 TF2 中的一個錯誤,由 TF 組確認( https://github.com/tensorflow/tensorflow/issues/38211

有一個臨時的解決方法,它顯式地將buildFoo function 中的張量列表返回到__init__ 例如見鏈接

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM