简体   繁体   English

Tensorflow 数据集如何获取数据生成器的形状?

[英]Tensorflow dataset how to get the shape of the generator of data?

Consider loading the following dataset from tensorflow datasets考虑从 tensorflow 数据集加载以下数据集

(ds_train, ds_test), ds_info= tfds.load('mnist', split=['train', 'test'],
                                        shuffle_files=True,
                                        as_supervised=True,with_info=True)

However, the website said不过,该网站称

#https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator
#Warning: SOME ARGUMENTS ARE DEPRECATED: (output_shapes, output_types). They will be removed in a future version. 
#Instructions for updating: Use output_signature instead

but none of the但没有一个

ds_train.output_shapes
ds_train.output_types
ds_train.output_signature

were working正在工作

A similar issue was mentioned here # https://github.com/tensorflow/datasets/issues/102 , so right now only the temporary fix这里提到了一个类似的问题 # https://github.com/tensorflow/datasets/issues/102 ,所以现在只有临时修复

shape_of_data=tf.compat.v1.data.get_output_shapes(ds_train)

was working, which returned正在工作,它返回

(TensorShape([None, 28, 28, 1]), TensorShape([None]))

Another updated function was working, but one could not get the TensorShape out of the argument另一个更新的 function 正在工作,但无法将 TensorShape 排除在参数之外

tf.data.DatasetSpec(ds_train) 

returned回来

DatasetSpec(<_OptionsDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>, TensorShape([]))

which could not be assigned.无法分配。

What's the updated function or attributes to get the shape of the generator/iterator?获取生成器/迭代器形状的更新的 function 或属性是什么?

One can use dataset.element_spec :可以使用dataset.element_spec

import tensorflow_datasets as tfds

(ds_train, ds_test), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train.element_spec
# (TensorSpec(shape=(28, 28, 1), dtype=tf.uint8, name=None),
#  TensorSpec(shape=(), dtype=tf.int64, name=None))

ds_train.element_spec[0].shape
# TensorShape([28, 28, 1])

Your variable ds_info has that information:您的变量ds_info包含以下信息:

height, width, channels = ds_info.features['image'].shape

Look at it like this:像这样看:

ds_info.features['image']
Image(shape=(28, 28, 1), dtype=tf.uint8)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM