Why does tf.gradients with tf.where return None?
I would like to change the input and/or the output according to some condition, but the resulting gradient is None.
How can I get the gradient while keeping the selection, as in the code below?
input1, input2 = ..., ...
output1, output2 = model(input1), model(input2)
input = tf.where(tf.less(output1, output2), input1, input2)
output = tf.where(tf.less(output1, output2), output1, output2)
grad, = tf.gradients(output, input)
I printed input and output, and their dimensions are the same as those of input1/output1 (and input2/output2). If I only compute tf.gradients(output1, input1), there is no problem. What is the difference between the two cases?
You can just do this:
input1, input2 = ..., ...
output1, output2 = model(input1), model(input2)
mask = tf.less(output1, output2)
input = tf.where(mask, input1, input2)
output = tf.where(mask, output1, output2)
grad = tf.add(*tf.gradients(output, [input1, input2]))
tf.gradients will return two tensors, each with zeros in the places where the corresponding input was not selected. Hence, grad will hold the correct aggregated gradient for input.
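As a small check of that claim, here is a minimal sketch, assuming TensorFlow 2.x, where tf.gradients has to be called inside a tf.function. The elementwise-square model, the sample values, and the name selected_gradients are hypothetical stand-ins, not part of the original question.

```python
import tensorflow as tf

def model(x):
    # Hypothetical stand-in for the real model: elementwise square.
    return tf.square(x)

@tf.function  # tf.gradients is only valid in a graph context
def selected_gradients(input1, input2):
    output1, output2 = model(input1), model(input2)
    mask = tf.less(output1, output2)
    output = tf.where(mask, output1, output2)
    # One gradient per original input; each is zero where the other was picked.
    grad1, grad2 = tf.gradients(output, [input1, input2])
    return grad1, grad2

input1 = tf.constant([1.0, 4.0])
input2 = tf.constant([3.0, 2.0])
grad1, grad2 = selected_gradients(input1, input2)
grad = grad1 + grad2  # aggregated gradient for the selected input

print(grad1.numpy())  # [2. 0.]
print(grad2.numpy())  # [0. 4.]
print(grad.numpy())   # [2. 4.]
```

With these values the first element is taken from input1 and the second from input2, so each partial gradient is zero exactly where its input lost the comparison.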
Your original approach does not work because, from the point of view of TensorFlow, there is no dependency between input and output. input is computed from input1 and input2, and output is also computed from input1 and input2, but there is no path in the graph from input to output, so there is no gradient.
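The missing path can be reproduced in a few lines. This is only a sketch, assuming TensorFlow 2.x in eager mode and using tf.GradientTape in place of tf.gradients; the model function, the sample values, and the variable names are hypothetical.

```python
import tensorflow as tf

def model(x):
    # Hypothetical stand-in for the real model: elementwise square.
    return tf.square(x)

a = tf.constant([1.0, 4.0])
b = tf.constant([3.0, 2.0])

with tf.GradientTape(persistent=True) as tape:
    tape.watch([a, b])
    out_a, out_b = model(a), model(b)
    mask = tf.less(out_a, out_b)
    picked_input = tf.where(mask, a, b)           # built from a and b
    picked_output = tf.where(mask, out_a, out_b)  # also built from a and b

# picked_output was never computed from picked_input, so nothing connects them.
unconnected = tape.gradient(picked_output, picked_input)
# The original inputs, by contrast, do feed picked_output.
connected = tape.gradient(picked_output, [a, b])

print(unconnected)                     # None
print([g.numpy() for g in connected])  # [2., 0.] and [0., 4.]
```

Both selected tensors are siblings that share the same parents, which is why differentiating one with respect to the other yields None while differentiating with respect to the parents works.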
In case someone is still stuck with NaN inputs, tf.where can be entirely replaced with:
tf.minimum(tensor_having_nans, value_that_replaces_nans)
Also, tf.maximum works, and gradients don't truncate. In the case of 'inf', only tf.minimum works.