[英]TensorFlow 2.0 - @tf.function and tf.unstack TypeError
[英]Equality comparison does not work inside TensorFlow 2.0 tf.function()
在關於TensorFlow 2.0 AutoGraphs的討論之后,我一直在玩耍並注意到直接指定了不等式比較(例如>
和<
,而等式比較則使用tf.equal
表示。
這是一個示例進行演示。 此函數使用>
運算符,在調用時效果很好 :
@tf.function
def greater_than_zero(value):
return value > 0
greater_than_zero(tf.constant(1))
# <tf.Tensor: id=1377, shape=(), dtype=bool, numpy=True>
greater_than_zero(tf.constant(-1))
# <tf.Tensor: id=1380, shape=(), dtype=bool, numpy=False>
這是另一個使用相等比較但不起作用的函數:
@tf.function
def equal_to_zero(value):
return value == 0
equal_to_zero(tf.constant(1))
# <tf.Tensor: id=1389, shape=(), dtype=bool, numpy=False> # OK...
equal_to_zero(tf.constant(0))
# <tf.Tensor: id=1392, shape=(), dtype=bool, numpy=False> # WHAT?
如果我將==
相等比較更改為tf.equal
,它將起作用。
@tf.function
def equal_to_zero2(value):
return tf.equal(value, 0)
equal_to_zero2(tf.constant(0))
# <tf.Tensor: id=1402, shape=(), dtype=bool, numpy=True>
我的問題是:為什么在tf.function
函數中使用不相等比較運算符,而相等比較卻不起作用?
我在文章“分析tf.function來發現Autograph的優點和精妙之處”的第3部分中對此行為進行了分析(強烈建議閱讀所有3部分,以了解如何在用tf.function
裝飾它之前正確編寫一個函數- tf.function
為答案的底部)。
對於__eq__
和tf.equal
問題,答案是:
總之:在
__eq__
運算符(tf.Tensor
)已被重寫,但運營商不使用tf.equal
檢查張量相等,它只是為Python變量身份檢查(如果您熟悉Java編程語言,這就像在字符串對象上使用的==運算符一樣。 原因是tf.Tensor
對象需要可tf.Tensor
,因為它在Tensorflow代碼庫中的任何地方都用作dict對象的鍵。
對於其他所有運算符,答案是AutoGraph不會將Python運算符轉換為TensorFlow邏輯運算符。 在“自動圖形如何(不)轉換運算符”部分中,我展示了每個Python運算符都將轉換為始終表示為false的圖形表示形式。
實際上,以下示例將產生“ wat”作為輸出
@tf.function
def if_elif(a, b):
if a > b:
tf.print("a > b", a, b)
elif a == b:
tf.print("a == b", a, b)
elif a < b:
tf.print("a < b", a, b)
else:
tf.print("wat")
x = tf.constant(1)
if_elif(x,x)
實際上,AutoGraph無法將Python代碼轉換為圖形代碼。 我們必須僅使用TensorFlow原語來幫助它。 在這種情況下,您的代碼將按預期工作。
@tf.function
def if_elif(a, b):
if tf.math.greater(a, b):
tf.print("a > b", a, b)
elif tf.math.equal(a, b):
tf.print("a == b", a, b)
elif tf.math.less(a, b):
tf.print("a < b", a, b)
else:
tf.print("wat")
我在這里給出了所有三篇文章的鏈接,我想您會發現它們很有用:
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.