简体   繁体   English

平等比较在TensorFlow 2.0 tf.function()中不起作用

[英]Equality comparison does not work inside TensorFlow 2.0 tf.function()

Following the discussion on TensorFlow 2.0 AutoGraphs , I've been playing around and noticed that inequality comparisons such as > and < are specified directly, whereas equality comparisons are represented using tf.equal . 在关于TensorFlow 2.0 AutoGraphs的讨论之后,我一直在玩耍并注意到直接指定了不等式比较(例如>< ,而等式比较则使用tf.equal表示。

Here's an example to demonstrate. 这是一个示例进行演示。 This function uses > operator and works well when called: 此函数使用>运算符,在调用时效果很好

@tf.function
def greater_than_zero(value):
    return value > 0

greater_than_zero(tf.constant(1))
#  <tf.Tensor: id=1377, shape=(), dtype=bool, numpy=True>
greater_than_zero(tf.constant(-1))
# <tf.Tensor: id=1380, shape=(), dtype=bool, numpy=False>

Here is another function that uses equality comparison, but does not work : 这是另一个使用相等比较但不起作用的函数:

@tf.function
def equal_to_zero(value):
    return value == 0

equal_to_zero(tf.constant(1))
# <tf.Tensor: id=1389, shape=(), dtype=bool, numpy=False>  # OK...

equal_to_zero(tf.constant(0))
# <tf.Tensor: id=1392, shape=(), dtype=bool, numpy=False>  # WHAT?

If I change the == equality comparison to tf.equal , it will work. 如果我将==相等比较更改为tf.equal ,它将起作用。

@tf.function
def equal_to_zero2(value):
    return tf.equal(value, 0)

equal_to_zero2(tf.constant(0))
# <tf.Tensor: id=1402, shape=(), dtype=bool, numpy=True>

My question is: Why does using inequality comparison operators work inside tf.function functions, whereas equality comparisons do not? 我的问题是:为什么在tf.function函数中使用不相等比较运算符,而相等比较却不起作用?

I analyzed this behavior in part 3 of the article "Analysing tf.function to discover Autograph strengths and subtleties" (and I highly recommend reading all the 3 parts to understand how to correctly write a function before decorating it with tf.function - links at the bottom of the answer). 我在文章“分析tf.function来发现Autograph的优点和精妙之处”的第3部分中对此行为进行了分析(强烈建议阅读所有3部分,以了解如何在用tf.function装饰它之前正确编写一个函数- tf.function为答案的底部)。

For the __eq__ and tf.equal question, the answer is: 对于__eq__tf.equal问题,答案是:

In short: the __eq__ operator (for tf.Tensor ) has been overridden, but the operator does not use tf.equal to check for the Tensor equality, it just checks for the Python variable identity (if you are familiar with the Java programming language, this is precisely like the == operator used on string objects). 总之:在__eq__运算符( tf.Tensor )已被重写,但运营商不使用tf.equal检查张量相等,它只是为Python变量身份检查(如果您熟悉Java编程语言,这就像在字符串对象上使用的==运算符一样。 The reason is that the tf.Tensor object needs to be hashable since it is used everywhere in the Tensorflow codebase as key for dict objects. 原因是tf.Tensor对象需要可tf.Tensor ,因为它在Tensorflow代码库中的任何地方都用作dict对象的键。

While for all the other operators, the answer is that AutoGraph doesn't convert Python operators to TensorFlow logical operators. 对于其他所有运算符,答案是AutoGraph不会将Python运算符转换为TensorFlow逻辑运算符。 In the section How AutoGraph (don't) converts the operators I showed that every Python operator gets converted to a graph representation that is always evaluated as false. 在“自动图形如何(不)转换运算符”部分中,我展示了每个Python运算符都将转换为始终表示为false的图形表示形式。

In fact, the following example produces as output "wat" 实际上,以下示例将产生“ wat”作为输出

@tf.function
def if_elif(a, b):
  if a > b:
    tf.print("a > b", a, b)
  elif a == b:
    tf.print("a == b", a, b)
  elif a < b:
    tf.print("a < b", a, b)
  else:
    tf.print("wat")
x = tf.constant(1)
if_elif(x,x)

In practice, AutoGraph is unable to convert Python code to graph code; 实际上,AutoGraph无法将Python代码转换为图形代码。 we have to help it using only the TensorFlow primitives. 我们必须仅使用TensorFlow原语来帮助它。 In that case, your code will work as you expect. 在这种情况下,您的代码将按预期工作。

@tf.function
def if_elif(a, b):
  if tf.math.greater(a, b):
    tf.print("a > b", a, b)
  elif tf.math.equal(a, b):
    tf.print("a == b", a, b)
  elif tf.math.less(a, b):
    tf.print("a < b", a, b)
  else:
    tf.print("wat")

I let here the links to all the three articles, I guess you'll find them usefult: 我在这里给出了所有三篇文章的链接,我想您会发现它们很有用:

part 1 , part 2 , part 3 第1部分第2 部分第3部分

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM