簡體   English   中英

ValueError:沒有為任何變量提供梯度 - Keras Tensorflow 2.0

[英]ValueError: No gradients provided for any variable - Keras Tensorflow 2.0

我正在嘗試在 TensorFlow 站點上遵循此示例,但它不起作用。

這是我的代碼:

import tensorflow as tf

def vectorize(vector_like):
    return tf.convert_to_tensor(vector_like)

def batchify(vector):
    '''Make a batch out of a single example'''
    return vectorize([vector])

data = [(batchify([0]), batchify([0, 0, 0])), (batchify([1]), batchify([0, 0, 0])), (batchify([2]), batchify([0, 0, 0]))]
num_hidden = 5
num_classes = 3

opt = tf.keras.optimizers.SGD(learning_rate=0.1)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
loss_fn = lambda: tf.keras.backend.cast(tf.keras.losses.mse(model(input), output), tf.float32)
var_list_fn = lambda: model.trainable_weights
for input, output in data:
    opt.minimize(loss_fn, var_list_fn)

有一段時間,我收到了關於具有錯誤數據類型(int 而不是 float)的損失函數的警告,這就是我將轉換添加到損失函數的原因。

我得到的不是網絡訓練,而是錯誤:

ValueError: 沒有為任何變量提供梯度:['sequential/dense/kernel:0', 'sequential/dense/bias:0', 'sequential/dense_1/kernel:0', 'sequential/dense_1/bias:0'] .

為什么梯度沒有通過? 我究竟做錯了什么?

如果要在 TF2 中操作漸變,則需要使用GradientTape 例如,以下作品。


opt = tf.keras.optimizers.SGD(learning_rate=0.1)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))

with tf.GradientTape() as tape:
  loss = tf.keras.backend.mean(tf.keras.losses.mse(model(input),tf.cast(output, tf.float32)))

gradients = tape.gradient(loss, model.trainable_variables)
opt.apply_gradients(zip(gradients, model.trainable_variables))

編輯

您實際上可以通過進行以下更改來使您的示例工作。

  • 僅對輸出使用loss_fn而不是完整的loss_fn (注意我也在做一個mean因為我們優化了損失的平均值)

通過“工作”,我的意思是它不會抱怨。 但是您需要進一步調查以確保它按預期工作。

loss_fn = lambda: tf.keras.backend.mean(tf.keras.losses.mse(model(input), tf.cast(output, tf.float32)))
var_list_fn = lambda: model.trainable_weights

opt.minimize(loss_fn, var_list_fn)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM