[英]Tensorflow dataset how to get the shape of the generator of data?
考慮從 tensorflow 數據集加載以下數據集
(ds_train, ds_test), ds_info= tfds.load('mnist', split=['train', 'test'],
shuffle_files=True,
as_supervised=True,with_info=True)
不過,該網站稱
#https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator
#Warning: SOME ARGUMENTS ARE DEPRECATED: (output_shapes, output_types). They will be removed in a future version.
#Instructions for updating: Use output_signature instead
但沒有一個
ds_train.output_shapes
ds_train.output_types
ds_train.output_signature
正在工作
這里提到了一個類似的問題 # https://github.com/tensorflow/datasets/issues/102 ,所以現在只有臨時修復
shape_of_data=tf.compat.v1.data.get_output_shapes(ds_train)
正在工作,它返回
(TensorShape([None, 28, 28, 1]), TensorShape([None]))
另一個更新的 function 正在工作,但無法將 TensorShape 排除在參數之外
tf.data.DatasetSpec(ds_train)
回來
DatasetSpec(<_OptionsDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>, TensorShape([]))
無法分配。
獲取生成器/迭代器形狀的更新的 function 或屬性是什么?
可以使用dataset.element_spec
:
import tensorflow_datasets as tfds
(ds_train, ds_test), ds_info = tfds.load(
"mnist",
split=["train", "test"],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
ds_train.element_spec
# (TensorSpec(shape=(28, 28, 1), dtype=tf.uint8, name=None),
# TensorSpec(shape=(), dtype=tf.int64, name=None))
ds_train.element_spec[0].shape
# TensorShape([28, 28, 1])
您的變量ds_info
包含以下信息:
height, width, channels = ds_info.features['image'].shape
像這樣看:
ds_info.features['image']
Image(shape=(28, 28, 1), dtype=tf.uint8)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.