Condition in tensorflow 2.x code reports error

I am migrating to tensorflow 2.x. The platform is Win 10 and the tf version is 2.3.1. Basically, this code

import tensorflow as tf

def do_nothing(x, y):
    m, n = x.shape
    if m==n:
        return x, y
    else:
        raise Exception('should never arrive here')

xys = [[tf.eye(2), tf.eye(3)], 
       [tf.eye(4), tf.eye(5)],]

@tf.function
def foo():
    return [do_nothing(x, y) for (x, y) in xys]
    
ans = foo()

works. Then I just change the condition m==n to tf.equal(m,n), as in

import tensorflow as tf

def do_nothing(x, y):
    m, n = x.shape
    if tf.equal(m, n):
        return x, y
    else:
        raise Exception('should never arrive here')

xys = [[tf.eye(2), tf.eye(3)], 
       [tf.eye(4), tf.eye(5)],]

@tf.function
def foo():
    return [do_nothing(x, y) for (x, y) in xys]
    
ans = foo()

The code no longer works. I am really puzzled. Is this a bug or what?

I tried additional experiments to reproduce the issue with less code. It looks like, if the condition uses something like tf.equal or tf.greater, the if and else clauses must both return tensors of the same type and size. See the code below.

import tensorflow as tf

#this piece works
@tf.function  
def foo1(x):
    if tf.greater(len(x), 0):
        return True
    else:
        return False
print(foo1(tf.zeros([1])))
print(foo1(tf.zeros([0])))

#this piece works too
@tf.function
def foo2(x):
    if len(x)>0: 
        return True
    else:
        raise Exception()
print(foo2(tf.zeros([1])))

#this piece no longer works
@tf.function
def foo3(x):
    if tf.greater(len(x), 0):
        return True
    else:
        raise Exception()
print(foo3(tf.zeros([1])))

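I think the reason is that tf.equal returns a Tensor of type bool, not a plain Python bool: http://tensorflow.biotecan.com/python/Python_1.8/tensorflow.google.cn/api_docs/python/tf/equal.html Inside a tf.function, AutoGraph converts an if statement whose condition is a Tensor into tf.cond and traces both branches, so the raise in the else branch runs during tracing whatever the value of the condition. A Python comparison such as m==n is evaluated once at trace time, and only the branch that is taken gets executed.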
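To make the both-branches-are-traced behaviour visible, here is a minimal sketch. The names probe and traced_branches are made up for illustration and are not from the question. It relies on the fact that Python side effects, such as appending to a list, run only while the function is being traced.

```python
import tensorflow as tf

# Hypothetical helper list: Python side effects run only during tracing,
# so this records which branches were traced.
traced_branches = []

@tf.function
def probe(x):
    # tf.greater returns a symbolic Tensor here, so AutoGraph lowers this
    # if statement to tf.cond and traces both branches.
    if tf.greater(tf.size(x), 0):
        traced_branches.append("if")
        return tf.constant(1)
    else:
        traced_branches.append("else")
        return tf.constant(0)

result = probe(tf.zeros([1]))
print(int(result))
print(sorted(set(traced_branches)))
```

Even though only the if branch contributes to the result, both labels end up in the list. A raise placed in the else branch would be hit in the same way, which matches the error in the question.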

Referring to the test I did in Google colab:

https://colab.research.google.com/drive/1sR99ScE-IDsWz0rNCH6VsWVclw1wz5oE#scrollTo=FsMqxbpnJ-Xg&line=1&uniqifier=1

import tensorflow as tf

def do_nothing(x, y):
    m, n = x.shape
    print(x)
    print(y)
    print(m,n)
    print(m==m)
    print(n==n)
    print(m==n)
    print(tf.equal(m,m))
    print(tf.equal(n,n))
    print(tf.equal(m,n))
    if tf.equal(m, n):
        return x, y
    else:
        raise Exception('should never arrive here')

xys = [[tf.eye(2), tf.eye(3)], 
       [tf.eye(4), tf.eye(5)],]

@tf.function
def foo():
    return [do_nothing(x, y) for (x, y) in xys]
    
ans = foo()

Output:

tf.Tensor(
[[1. 0.]
 [0. 1.]], shape=(2, 2), dtype=float32)
tf.Tensor(
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]], shape=(3, 3), dtype=float32)
2 2
True
True
True
Tensor("Equal:0", shape=(), dtype=bool)
Tensor("Equal_1:0", shape=(), dtype=bool)
Tensor("Equal_2:0", shape=(), dtype=bool)

---------------------------------------------------------------------------

Exception                                 Traceback (most recent call last)

<ipython-input-12-24121e0806b4> in <module>()
     24     return [do_nothing(x, y) for (x, y) in xys]
     25 
---> 26 ans = foo()

8 frames

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    975           except Exception as e:  # pylint:disable=broad-except
    976             if hasattr(e, "ag_error_metadata"):
--> 977               raise e.ag_error_metadata.to_exception(e)
    978             else:
    979               raise
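If the shapes are known statically, as they are for tf.eye(2), keeping the plain Python comparison m==n is probably the simplest option. For a check that has to happen when the graph runs, one possible approach is tf.debugging.assert_equal instead of a Python raise. The sketch below is only an illustration: require_square is a made-up name, and the input_signature with unknown dimensions is my assumption, added so that the shape is only known at run time.

```python
import tensorflow as tf

# Hypothetical example: the input_signature leaves both dimensions unknown,
# so the check cannot be resolved while tracing.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
def require_square(x):
    # tf.shape gives the runtime shape as a Tensor.
    shape = tf.shape(x)
    # The assertion becomes a graph op that is evaluated on every call,
    # unlike a Python raise, which would fire during tracing.
    tf.debugging.assert_equal(shape[0], shape[1], message="x must be square")
    return x

square = require_square(tf.eye(2))  # passes the check

try:
    require_square(tf.zeros([2, 3]))
    rejected = False
except tf.errors.InvalidArgumentError:
    rejected = True

print(square.shape, rejected)
```

The failing call surfaces as tf.errors.InvalidArgumentError rather than a generic Exception, so callers that caught the original exception would need adjusting.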
