简体   繁体   English

tf.shape() 在张量流中得到错误的形状

[英]tf.shape() get wrong shape in tensorflow

I define a tensor like this:我定义了一个这样的张量:

x = tf.get_variable("x", [100])

But when I try to print shape of tensor :但是当我尝试打印张量的形状时:

print( tf.shape(x) )

I get Tensor("Shape:0", shape=(1,), dtype=int32) , why the result of output should not be shape=(100)我得到Tensor("Shape:0", shape=(1,), dtype=int32) ,为什么输出的结果不应该是 shape=(100)

tf.shape(input, name=None) returns a 1-D integer tensor representing the shape of input. tf.shape(input, name=None)返回一个表示输入形状的一维整数张量。

You're looking for: x.get_shape() that returns the TensorShape of the x variable.您正在寻找: x.get_shape()返回x变量的TensorShape

Update: I wrote an article to clarify the dynamic/static shapes in Tensorflow because of this answer: https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/更新:由于这个答案,我写了一篇文章来阐明 Tensorflow 中的动态/静态形状: https ://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/

Clarification:澄清:

tf.shape(x) creates an op and returns an object which stands for the output of the constructed op, which is what you are printing currently. tf.shape(x) 创建一个操作并返回一个代表构造操作的输出的对象,这就是您当前正在打印的内容。 To get the shape, run the operation in a session:要获取形状,请在会话中运行操作:

matA = tf.constant([[7, 8], [9, 10]])
shapeOp = tf.shape(matA) 
print(shapeOp) #Tensor("Shape:0", shape=(2,), dtype=int32)
with tf.Session() as sess:
   print(sess.run(shapeOp)) #[2 2]

credit: After looking at the above answer, I saw the answer to tf.rank function in Tensorflow which I found more helpful and I have tried rephrasing it here.信用:在查看上述答案后,我在 Tensorflow 中看到了tf.rank 函数的答案,我发现它更有帮助,我尝试在此处重新措辞

Just a quick example, to make things clear:只是一个简单的例子,让事情清楚:

a = tf.Variable(tf.zeros(shape=(2, 3, 4)))
print('-'*60)
print("v1", tf.shape(a))
print('-'*60)
print("v2", a.get_shape())
print('-'*60)
with tf.Session() as sess:
    print("v3", sess.run(tf.shape(a)))
print('-'*60)
print("v4",a.shape)

Output will be:输出将是:

------------------------------------------------------------
v1 Tensor("Shape:0", shape=(3,), dtype=int32)
------------------------------------------------------------
v2 (2, 3, 4)
------------------------------------------------------------
v3 [2 3 4]
------------------------------------------------------------
v4 (2, 3, 4)

Also this should be helpful: How to understand static shape and dynamic shape in TensorFlow?这也应该有帮助: 如何理解 TensorFlow 中的静态形状和动态形状?

Similar question is nicely explained in TF FAQ : TF FAQ 中很好地解释了类似的问题:

In TensorFlow, a tensor has both a static (inferred) shape and a dynamic (true) shape.在 TensorFlow 中,张量具有静态(推断)形状和动态(真实)形状。 The static shape can be read using the tf.Tensor.get_shape method: this shape is inferred from the operations that were used to create the tensor, and may be partially complete.可以使用tf.Tensor.get_shape方法读取静态形状:此形状是从用于创建张量的操作中推断出来的,可能是部分完整的。 If the static shape is not fully defined, the dynamic shape of a Tensor t can be determined by evaluating tf.shape(t) .如果静态形状没有完全定义,张量 t 的动态形状可以通过评估tf.shape(t)来确定。

So tf.shape() returns you a tensor, will always have a size of shape=(N,) , and can be calculated in a session:所以tf.shape()返回一个张量,大小总是shape=(N,) ,并且可以在会话中计算:

a = tf.Variable(tf.zeros(shape=(2, 3, 4)))
with tf.Session() as sess:
    print sess.run(tf.shape(a))

On the other hand you can extract the static shape by using x.get_shape().as_list() and this can be calculated anywhere.另一方面,您可以使用x.get_shape().as_list()提取静态形状,这可以在任何地方计算。

Simply, use tensor.shape to get the static shape :简单地说,使用tensor.shape来获得静态形状

In [102]: a = tf.placeholder(tf.float32, [None, 128])

# returns [None, 128]
In [103]: a.shape.as_list()
Out[103]: [None, 128]

Whereas to get the dynamic shape , use tf.shape() :而要获得动态形状,请使用tf.shape()

dynamic_shape = tf.shape(a)

You can also get the shape as you'd in NumPy with your_tensor.shape as in the following example.您还可以像在 NumPy 中一样使用your_tensor.shape获取形状,如下例所示。

In [11]: tensr = tf.constant([[1, 2, 3, 4, 5], [2, 3, 4, 5, 6]])

In [12]: tensr.shape
Out[12]: TensorShape([Dimension(2), Dimension(5)])

In [13]: list(tensr.shape)
Out[13]: [Dimension(2), Dimension(5)]

In [16]: print(tensr.shape)
(2, 5)

Also, this example, for tensors which can be eval uated.此外,这个例子,对于可以eval张量。

In [33]: tf.shape(tensr).eval().tolist()
Out[33]: [2, 5]

Tensorflow 2.0 Compatible Answer : Tensorflow 2.x (>= 2.0) compatible answer for nessuno's solution is shown below: Tensorflow 2.0 兼容答案Tensorflow 2.x (>= 2.0)解决方案的Tensorflow 2.x (>= 2.0)兼容答案如下所示:

x = tf.compat.v1.get_variable("x", [100])

print(x.get_shape())

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM