简体   繁体   English

如何在 Tensorflow 中 decode_jpeg 后获取图像的形状?

[英]How to get the shape of an image after decode_jpeg in Tensorflow?

I have an image which I have feed into tf.image.decode_jpeg:我有一个图像,我已将其输入 tf.image.decode_jpeg:

img = tf.io.read_file(file_path)
img = tf.image.decode_jpeg(img, channels=3)

and I am trying to get its height and with width img.shape[0] and img.shape[1] , but both return None .我试图获得它的高度和宽度img.shape[0]img.shape[1] ,但都返回None Actually, img.shape returns (None, None, 3) .实际上, img.shape返回(None, None, 3)

I am using this inside a function that is mapped into a tf.data.Dataset .我在映射到tf.data.Dataset的函数中使用它。 How can I get the real shape of the image?如何获得图像的真实形状?

update:更新:

At the moment, I have found a solution that consists in wrapping the code with tf.py_function to execute it eagerly because the dataset creates an internal graph.目前,我找到了一个解决方案,即用tf.py_function包装代码以tf.py_function地执行它,因为数据集创建了一个内部图。 I would appreciate If anyone has another solution to do it in a pure graph way, which would improve performance.如果有人有另一种解决方案以纯图形方式进行处理,我将不胜感激,这将提高性能。

Since you have already found a solution to get the shape of the image by wrapping your code around tf.py_function .由于您已经找到了通过将代码围绕tf.py_function来获取图像形状的解决tf.py_function Providing the solution here for the benefit of the community.在这里提供解决方案以造福社区。

However, since eager execution is enabled by default in TensorFlow 2, you can get the shape directly like mentioned below without having to wrap it around tf.py_function .但是,由于在 TensorFlow 2 中默认启用了tf.py_function ,因此您可以直接获得如下所述的形状,而无需将其包裹在tf.py_function周围。

Tensorflow 1.x: TensorFlow 1.x:

img = tf.io.read_file("sample.jpg")
img = tf.image.decode_jpeg(img, channels=3)

with tf.Session() as sess:
  array = img.eval(session=sess)
  height = array.shape[0]
  width = array.shape[1]
  print("Height:",height)
  print("Width:",width) 

Height:320高度:320

Width:320宽度:320

Tensorflow 2:张量流2:

img = tf.io.read_file("sample.jpg")
img = tf.image.decode_jpeg(img, channels=3)
height = img.shape[0]
width = img.shape[1]

print("Height:",height)
print("Width:",width) 

Height: 320高度:320

Width: 320宽度:320

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM