[英]tf.shape() get wrong shape in tensorflow
I define a tensor like this:我定义了一个这样的张量:
x = tf.get_variable("x", [100])
But when I try to print shape of tensor :但是当我尝试打印张量的形状时:
print( tf.shape(x) )
I get Tensor("Shape:0", shape=(1,), dtype=int32) , why the result of output should not be shape=(100)我得到Tensor("Shape:0", shape=(1,), dtype=int32) ,为什么输出的结果不应该是 shape=(100)
tf.shape(input, name=None) returns a 1-D integer tensor representing the shape of input. tf.shape(input, name=None)返回一个表示输入形状的一维整数张量。
You're looking for: x.get_shape()
that returns the TensorShape
of the x
variable.您正在寻找:
x.get_shape()
返回x
变量的TensorShape
。
Update: I wrote an article to clarify the dynamic/static shapes in Tensorflow because of this answer: https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/更新:由于这个答案,我写了一篇文章来阐明 Tensorflow 中的动态/静态形状: https ://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/
Clarification:澄清:
tf.shape(x) creates an op and returns an object which stands for the output of the constructed op, which is what you are printing currently. tf.shape(x) 创建一个操作并返回一个代表构造操作的输出的对象,这就是您当前正在打印的内容。 To get the shape, run the operation in a session:
要获取形状,请在会话中运行操作:
matA = tf.constant([[7, 8], [9, 10]])
shapeOp = tf.shape(matA)
print(shapeOp) #Tensor("Shape:0", shape=(2,), dtype=int32)
with tf.Session() as sess:
print(sess.run(shapeOp)) #[2 2]
credit: After looking at the above answer, I saw the answer to tf.rank function in Tensorflow which I found more helpful and I have tried rephrasing it here.信用:在查看上述答案后,我在 Tensorflow 中看到了tf.rank 函数的答案,我发现它更有帮助,我尝试在此处重新措辞。
Just a quick example, to make things clear:只是一个简单的例子,让事情清楚:
a = tf.Variable(tf.zeros(shape=(2, 3, 4)))
print('-'*60)
print("v1", tf.shape(a))
print('-'*60)
print("v2", a.get_shape())
print('-'*60)
with tf.Session() as sess:
print("v3", sess.run(tf.shape(a)))
print('-'*60)
print("v4",a.shape)
Output will be:输出将是:
------------------------------------------------------------
v1 Tensor("Shape:0", shape=(3,), dtype=int32)
------------------------------------------------------------
v2 (2, 3, 4)
------------------------------------------------------------
v3 [2 3 4]
------------------------------------------------------------
v4 (2, 3, 4)
Also this should be helpful: How to understand static shape and dynamic shape in TensorFlow?这也应该有帮助: 如何理解 TensorFlow 中的静态形状和动态形状?
Similar question is nicely explained in TF FAQ : TF FAQ 中很好地解释了类似的问题:
In TensorFlow, a tensor has both a static (inferred) shape and a dynamic (true) shape.
在 TensorFlow 中,张量具有静态(推断)形状和动态(真实)形状。 The static shape can be read using the
tf.Tensor.get_shape
method: this shape is inferred from the operations that were used to create the tensor, and may be partially complete.可以使用
tf.Tensor.get_shape
方法读取静态形状:此形状是从用于创建张量的操作中推断出来的,可能是部分完整的。 If the static shape is not fully defined, the dynamic shape of a Tensor t can be determined by evaluatingtf.shape(t)
.如果静态形状没有完全定义,张量 t 的动态形状可以通过评估
tf.shape(t)
来确定。
So tf.shape()
returns you a tensor, will always have a size of shape=(N,)
, and can be calculated in a session:所以
tf.shape()
返回一个张量,大小总是shape=(N,)
,并且可以在会话中计算:
a = tf.Variable(tf.zeros(shape=(2, 3, 4)))
with tf.Session() as sess:
print sess.run(tf.shape(a))
On the other hand you can extract the static shape by using x.get_shape().as_list()
and this can be calculated anywhere.另一方面,您可以使用
x.get_shape().as_list()
提取静态形状,这可以在任何地方计算。
Simply, use tensor.shape
to get the static shape :简单地说,使用
tensor.shape
来获得静态形状:
In [102]: a = tf.placeholder(tf.float32, [None, 128])
# returns [None, 128]
In [103]: a.shape.as_list()
Out[103]: [None, 128]
Whereas to get the dynamic shape , use tf.shape()
:而要获得动态形状,请使用
tf.shape()
:
dynamic_shape = tf.shape(a)
You can also get the shape as you'd in NumPy with your_tensor.shape
as in the following example.您还可以像在 NumPy 中一样使用
your_tensor.shape
获取形状,如下例所示。
In [11]: tensr = tf.constant([[1, 2, 3, 4, 5], [2, 3, 4, 5, 6]])
In [12]: tensr.shape
Out[12]: TensorShape([Dimension(2), Dimension(5)])
In [13]: list(tensr.shape)
Out[13]: [Dimension(2), Dimension(5)]
In [16]: print(tensr.shape)
(2, 5)
Also, this example, for tensors which can be eval
uated.此外,这个例子,对于可以
eval
张量。
In [33]: tf.shape(tensr).eval().tolist()
Out[33]: [2, 5]
Tensorflow 2.0 Compatible Answer : Tensorflow 2.x (>= 2.0)
compatible answer for nessuno's solution is shown below: Tensorflow 2.0 兼容答案:
Tensorflow 2.x (>= 2.0)
解决方案的Tensorflow 2.x (>= 2.0)
兼容答案如下所示:
x = tf.compat.v1.get_variable("x", [100])
print(x.get_shape())
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.