簡體   English   中英

平等比較在TensorFlow 2.0 tf.function()中不起作用

[英]Equality comparison does not work inside TensorFlow 2.0 tf.function()

在關於TensorFlow 2.0 AutoGraphs的討論之后,我一直在玩耍並注意到直接指定了不等式比較(例如>< ,而等式比較則使用tf.equal表示。

這是一個示例進行演示。 此函數使用>運算符,在調用時效果很好

@tf.function
def greater_than_zero(value):
    return value > 0

greater_than_zero(tf.constant(1))
#  <tf.Tensor: id=1377, shape=(), dtype=bool, numpy=True>
greater_than_zero(tf.constant(-1))
# <tf.Tensor: id=1380, shape=(), dtype=bool, numpy=False>

這是另一個使用相等比較但不起作用的函數:

@tf.function
def equal_to_zero(value):
    return value == 0

equal_to_zero(tf.constant(1))
# <tf.Tensor: id=1389, shape=(), dtype=bool, numpy=False>  # OK...

equal_to_zero(tf.constant(0))
# <tf.Tensor: id=1392, shape=(), dtype=bool, numpy=False>  # WHAT?

如果我將==相等比較更改為tf.equal ,它將起作用。

@tf.function
def equal_to_zero2(value):
    return tf.equal(value, 0)

equal_to_zero2(tf.constant(0))
# <tf.Tensor: id=1402, shape=(), dtype=bool, numpy=True>

我的問題是:為什么在tf.function函數中使用不相等比較運算符,而相等比較卻不起作用?

我在文章“分析tf.function來發現Autograph的優點和精妙之處”的第3部分中對此行為進行了分析(強烈建議閱讀所有3部分,以了解如何在用tf.function裝飾它之前正確編寫一個函數- tf.function為答案的底部)。

對於__eq__tf.equal問題,答案是:

總之:在__eq__運算符( tf.Tensor )已被重寫,但運營商不使用tf.equal檢查張量相等,它只是為Python變量身份檢查(如果您熟悉Java編程語言,這就像在字符串對象上使用的==運算符一樣。 原因是tf.Tensor對象需要可tf.Tensor ,因為它在Tensorflow代碼庫中的任何地方都用作dict對象的鍵。

對於其他所有運算符,答案是AutoGraph不會將Python運算符轉換為TensorFlow邏輯運算符。 在“自動圖形如何(不)轉換運算符”部分中,我展示了每個Python運算符都將轉換為始終表示為false的圖形表示形式。

實際上,以下示例將產生“ wat”作為輸出

@tf.function
def if_elif(a, b):
  if a > b:
    tf.print("a > b", a, b)
  elif a == b:
    tf.print("a == b", a, b)
  elif a < b:
    tf.print("a < b", a, b)
  else:
    tf.print("wat")
x = tf.constant(1)
if_elif(x,x)

實際上,AutoGraph無法將Python代碼轉換為圖形代碼。 我們必須僅使用TensorFlow原語來幫助它。 在這種情況下,您的代碼將按預期工作。

@tf.function
def if_elif(a, b):
  if tf.math.greater(a, b):
    tf.print("a > b", a, b)
  elif tf.math.equal(a, b):
    tf.print("a == b", a, b)
  elif tf.math.less(a, b):
    tf.print("a < b", a, b)
  else:
    tf.print("wat")

我在這里給出了所有三篇文章的鏈接,我想您會發現它們很有用:

第1部分第2 部分第3部分

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM