[英]How to preserve dict keys in tf.data.Dataset.from_generator?
我有这样的非矩形数据:
samples_train = {'data': [np.array([[1,1]]), np.array([[1,1],[2,2]]), np.array([[1,1],[2,2],[3,3]])],
'labels': [1,2,3]}
它是一个包含 arrays 列表的字典,其中shape=[variable, 2]
。
由于我有一个自定义训练循环,我想通过键“数据”和“标签”访问数据(我有存储的其他键),因此是 dict 格式。
我特别不想将它们填充到一个常见的序列长度(到目前为止,我确实填充了它们,并且上面的from_tensor_slices
方法适用于填充的相同长度的序列)。 但现在我需要它们而不是填充。
如果我尝试:
ds = tf.data.Dataset.from_tensor_slices(samples_train)
我得到这个错误,这在某种程度上是有道理的:
ValueError:无法将非矩形 Python 序列转换为张量。
所以这个问题的答案建议如下:
ds = tf.data.Dataset.from_generator(
lambda: iter(zip(samples_train['data'], samples_train['labels'])),
output_types=(tf.float32, tf.float32)
)
通过检查可以正常工作:
for batch in ds:
print(batch)
--> output:
(<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[1., 1.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
(<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[1., 1.],
[2., 2.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
(<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[1., 1.],
[2., 2.],
[3., 3.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=3.0>)
但是这样一来,我就松开了我的 dict 键。
但是,我希望能够像这样访问它们:
for batch in ds:
print(batch['data'])
print(batch['labels'])
如何在数据集中保留这些 dict 键?
您可以编写一个生成器 function 产生一个字典,如下所示:
def my_generator(my_dict):
for data in zip(*[my_dict[key] for key in my_dict]):
yield {key:d for key,d in zip(my_dict.keys(), data)}
并在from_generator
function 中设置正确的output_types
。
结果是
>>> ds = tf.data.Dataset.from_generator(
lambda: my_generator(samples_train),
output_types={"data": tf.float32, "labels": tf.float32})
>>> for batch in ds:
print(batch['data'])
print(batch['labels'])
tf.Tensor([[1. 1.]], shape=(1, 2), dtype=float32)
tf.Tensor(1.0, shape=(), dtype=float32)
tf.Tensor(
[[1. 1.]
[2. 2.]], shape=(2, 2), dtype=float32)
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(
[[1. 1.]
[2. 2.]
[3. 3.]], shape=(3, 2), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.