[英]How to preserve dict keys in tf.data.Dataset.from_generator?
我有這樣的非矩形數據:
samples_train = {'data': [np.array([[1,1]]), np.array([[1,1],[2,2]]), np.array([[1,1],[2,2],[3,3]])],
'labels': [1,2,3]}
它是一個包含 arrays 列表的字典,其中shape=[variable, 2]
。
由於我有一個自定義訓練循環,我想通過鍵“數據”和“標簽”訪問數據(我有存儲的其他鍵),因此是 dict 格式。
我特別不想將它們填充到一個常見的序列長度(到目前為止,我確實填充了它們,並且上面的from_tensor_slices
方法適用於填充的相同長度的序列)。 但現在我需要它們而不是填充。
如果我嘗試:
ds = tf.data.Dataset.from_tensor_slices(samples_train)
我得到這個錯誤,這在某種程度上是有道理的:
ValueError:無法將非矩形 Python 序列轉換為張量。
所以這個問題的答案建議如下:
ds = tf.data.Dataset.from_generator(
lambda: iter(zip(samples_train['data'], samples_train['labels'])),
output_types=(tf.float32, tf.float32)
)
通過檢查可以正常工作:
for batch in ds:
print(batch)
--> output:
(<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[1., 1.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
(<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[1., 1.],
[2., 2.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
(<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[1., 1.],
[2., 2.],
[3., 3.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=3.0>)
但是這樣一來,我就松開了我的 dict 鍵。
但是,我希望能夠像這樣訪問它們:
for batch in ds:
print(batch['data'])
print(batch['labels'])
如何在數據集中保留這些 dict 鍵?
您可以編寫一個生成器 function 產生一個字典,如下所示:
def my_generator(my_dict):
for data in zip(*[my_dict[key] for key in my_dict]):
yield {key:d for key,d in zip(my_dict.keys(), data)}
並在from_generator
function 中設置正確的output_types
。
結果是
>>> ds = tf.data.Dataset.from_generator(
lambda: my_generator(samples_train),
output_types={"data": tf.float32, "labels": tf.float32})
>>> for batch in ds:
print(batch['data'])
print(batch['labels'])
tf.Tensor([[1. 1.]], shape=(1, 2), dtype=float32)
tf.Tensor(1.0, shape=(), dtype=float32)
tf.Tensor(
[[1. 1.]
[2. 2.]], shape=(2, 2), dtype=float32)
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(
[[1. 1.]
[2. 2.]
[3. 3.]], shape=(3, 2), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.