繁体   English   中英

如何在 tf.data.Dataset.from_generator 中保留字典键?

[英]How to preserve dict keys in tf.data.Dataset.from_generator?

我有这样的非矩形数据:

samples_train = {'data': [np.array([[1,1]]), np.array([[1,1],[2,2]]), np.array([[1,1],[2,2],[3,3]])],          
                 'labels': [1,2,3]}

它是一个包含 arrays 列表的字典,其中shape=[variable, 2]

由于我有一个自定义训练循环,我想通过键“数据”和“标签”访问数据(我有存储的其他键),因此是 dict 格式。

我特别不想将它们填充到一个常见的序列长度(到目前为止,我确实填充了它们,并且上面的from_tensor_slices方法适用于填充的相同长度的序列)。 但现在我需要它们而不是填充。

如果我尝试:

ds = tf.data.Dataset.from_tensor_slices(samples_train)

我得到这个错误,这在某种程度上是有道理的:

ValueError:无法将非矩形 Python 序列转换为张量。

所以这个问题的答案建议如下:

ds = tf.data.Dataset.from_generator(
    lambda: iter(zip(samples_train['data'], samples_train['labels'])), 
    output_types=(tf.float32, tf.float32)
)

通过检查可以正常工作:

for batch in ds:
    print(batch)

--> output:

(<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[1., 1.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
(<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[1., 1.],
       [2., 2.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
(<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[1., 1.],
       [2., 2.],
       [3., 3.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=3.0>)

但是这样一来,我就松开了我的 dict 键。

但是,我希望能够像这样访问它们:

for batch in ds:
    print(batch['data'])
    print(batch['labels'])

如何在数据集中保留这些 dict 键?

您可以编写一个生成器 function 产生一个字典,如下所示:

def my_generator(my_dict):
    for data in zip(*[my_dict[key] for key in my_dict]):
        yield {key:d for key,d in zip(my_dict.keys(), data)}

并在from_generator function 中设置正确的output_types

结果是

>>> ds = tf.data.Dataset.from_generator(
    lambda: my_generator(samples_train),
    output_types={"data": tf.float32, "labels": tf.float32})  
>>> for batch in ds:
      print(batch['data'])
      print(batch['labels'])
tf.Tensor([[1. 1.]], shape=(1, 2), dtype=float32)
tf.Tensor(1.0, shape=(), dtype=float32)
tf.Tensor(
[[1. 1.]
 [2. 2.]], shape=(2, 2), dtype=float32)
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(
[[1. 1.]
 [2. 2.]
 [3. 3.]], shape=(3, 2), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM