
Why does tf.gradients with tf.where return None?

I would like to change the input and/or output according to some condition, but the gradient is None.

How can I get the gradient and keep the selection, as in the code below?

input1, input2 = ..., ...
output1, output2 = model(input1), model(input2)

input = tf.where(tf.less(output1, output2), input1, input2)
output = tf.where(tf.less(output1, output2), output1, output2)

grad, = tf.gradients(output, input)  # grad is None

I printed the types of input and output; their shapes are the same as those of input1/output1 (input2/output2). And if I only compute tf.gradients(output1, input1), there is no problem. What is the difference between the two cases?

You can just do this:

input1, input2 = ..., ...
output1, output2 = model(input1), model(input2)
mask = tf.less(output1, output2)
input = tf.where(mask, input1, input2)
output = tf.where(mask, output1, output2)
# Differentiate with respect to the original inputs, not the tf.where result.
grad = tf.add(*tf.gradients(output, [input1, input2]))

tf.gradients will return two tensors, with zeros at the positions where the corresponding input was not selected. Hence, grad will hold the correct aggregated gradient for input.
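For illustration, here is a minimal sketch of that behavior. The toy x ** 2 stand-in for model and the TF1-style graph mode are assumptions, not part of the original question:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

input1 = tf.constant([1.0, 4.0])
input2 = tf.constant([3.0, 2.0])
output1, output2 = input1 ** 2, input2 ** 2  # toy stand-in for model(...)

mask = tf.less(output1, output2)
output = tf.where(mask, output1, output2)

# Each returned tensor is zero where the corresponding input was not selected.
g1, g2 = tf.gradients(output, [input1, input2])
grad = tf.add(g1, g2)

with tf.Session() as sess:
    print(sess.run([g1, g2, grad]))  # [array([2., 0.]), array([0., 4.]), array([2., 4.])]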

Your original approach does not work because, from the point of view of TensorFlow, there is no dependency between input and output. input is computed from input1 and input2, and output is also computed from input1 and input2, but there is no path in the graph from input to output, so there is no gradient.
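As a minimal sketch of that failure mode (again using assumed toy constants), differentiating the selected output with respect to the selected input yields no gradient, because neither tensor feeds into the other:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

input1 = tf.constant([1.0, 4.0])
input2 = tf.constant([3.0, 2.0])
output1, output2 = input1 ** 2, input2 ** 2

mask = tf.less(output1, output2)
selected_input = tf.where(mask, input1, input2)
selected_output = tf.where(mask, output1, output2)

# selected_output is computed from input1/input2, not from selected_input,
# so there is no path between them in the graph.
print(tf.gradients(selected_output, selected_input))  # [None]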

In case someone is still stuck with NaN inputs, tf.where can be entirely replaced with:

tf.minimum(tensor_having_nans, value_that_replaces_nans)

Also, tf.maximum works, and the gradients don't get truncated. In the case of 'inf', only tf.minimum works.
