[英]Variable tf.Variable has 'None' for gradient in TensorFlow Probability
I'm having trouble constructing a basic BNN in TFP.我在 TFP 中构建基本 BNN 时遇到问题。 I'm new to TFP and BNNs in general, so I apologize if I've missed something simple.一般来说,我是 TFP 和 BNN 的新手,所以如果我错过了一些简单的东西,我深表歉意。
I can train a basic NN in Tensorflow by doing the following:我可以通过执行以下操作在 Tensorflow 中训练一个基本的神经网络:
model = keras.Sequential([
keras.layers.Dense(units=100, activation='relu'),
keras.layers.Dense(units=50, activation='relu'),
keras.layers.Dense(units=5, activation='softmax')
])
model.compile(optimizer=optimizer,
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
history = model.fit(
training_data.repeat(),
epochs=100,
steps_per_epoch=(X_train.shape[0]//1024),
validation_data=test_data.repeat(),
validation_steps=2
)
However, I have trouble when trying to implement a similar architecture with tfp DenseFlipout layers:但是,我在尝试使用 tfp DenseFlipout 层实现类似的架构时遇到了麻烦:
model = keras.Sequential([
tfp.layers.DenseFlipout(units=100, activation='relu'),
tfp.layers.DenseFlipout(units=10, activation='relu'),
tfp.layers.DenseFlipout(units=5, activation='softmax')
])
model.compile(optimizer=optimizer,
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
history = model.fit(
training_data.repeat(),
epochs=100,
steps_per_epoch=(X_train.shape[0]//1024),
validation_data=test_data.repeat(),
validation_steps=2
)
I get the following Value error:我收到以下值错误:
ValueError:
Variable <tf.Variable 'sequential_11/dense_flipout_15/kernel_posterior_loc:0'
shape=(175, 100) dtype=float32> has `None` for gradient.
Please make sure that all of your ops have a gradient defined (i.e. are differentiable).
Common ops without gradient: K.argmax, K.round, K.eval.
I've done some googling, and have looked around the TFP docs, but am at a loss so thought I would share the issue.我已经做了一些谷歌搜索,并查看了 TFP 文档,但我不知所措,所以我想我会分享这个问题。 Have I missed something obvious?我错过了一些明显的东西吗?
Thanks in advance.提前致谢。
I expect it's because you're using TensorFlow 2, are you?我猜是因为您使用的是 TensorFlow 2,是吗? It isn't fully supported yet.它还没有得到完全支持。 If so, downgrading to 1.14 should get it working.如果是这样,降级到 1.14 应该可以让它工作。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.