簡體   English   中英

如何在 tf.data.Dataset.from_generator 中保留字典鍵?

[英]How to preserve dict keys in tf.data.Dataset.from_generator?

我有這樣的非矩形數據:

samples_train = {'data': [np.array([[1,1]]), np.array([[1,1],[2,2]]), np.array([[1,1],[2,2],[3,3]])],          
                 'labels': [1,2,3]}

它是一個包含 arrays 列表的字典,其中shape=[variable, 2]

由於我有一個自定義訓練循環,我想通過鍵“數據”和“標簽”訪問數據(我有存儲的其他鍵),因此是 dict 格式。

我特別不想將它們填充到一個常見的序列長度(到目前為止,我確實填充了它們,並且上面的from_tensor_slices方法適用於填充的相同長度的序列)。 但現在我需要它們而不是填充。

如果我嘗試:

ds = tf.data.Dataset.from_tensor_slices(samples_train)

我得到這個錯誤,這在某種程度上是有道理的:

ValueError:無法將非矩形 Python 序列轉換為張量。

所以這個問題的答案建議如下:

ds = tf.data.Dataset.from_generator(
    lambda: iter(zip(samples_train['data'], samples_train['labels'])), 
    output_types=(tf.float32, tf.float32)
)

通過檢查可以正常工作:

for batch in ds:
    print(batch)

--> output:

(<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[1., 1.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
(<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[1., 1.],
       [2., 2.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
(<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[1., 1.],
       [2., 2.],
       [3., 3.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=3.0>)

但是這樣一來,我就松開了我的 dict 鍵。

但是,我希望能夠像這樣訪問它們:

for batch in ds:
    print(batch['data'])
    print(batch['labels'])

如何在數據集中保留這些 dict 鍵?

您可以編寫一個生成器 function 產生一個字典,如下所示:

def my_generator(my_dict):
    for data in zip(*[my_dict[key] for key in my_dict]):
        yield {key:d for key,d in zip(my_dict.keys(), data)}

並在from_generator function 中設置正確的output_types

結果是

>>> ds = tf.data.Dataset.from_generator(
    lambda: my_generator(samples_train),
    output_types={"data": tf.float32, "labels": tf.float32})  
>>> for batch in ds:
      print(batch['data'])
      print(batch['labels'])
tf.Tensor([[1. 1.]], shape=(1, 2), dtype=float32)
tf.Tensor(1.0, shape=(), dtype=float32)
tf.Tensor(
[[1. 1.]
 [2. 2.]], shape=(2, 2), dtype=float32)
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(
[[1. 1.]
 [2. 2.]
 [3. 3.]], shape=(3, 2), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM