[英]How to save ParallelMapDataset?
我有一個輸入數據集(我們將其命名為 ds),一個 function 傳遞給編碼器(名為embedder
的模型)。 我想制作一個編碼數據集並將其保存到文件中。 我試圖做什么:
轉換器 function:
def generate_embedding(image, label, embedder):
return (embedder(image)[0], label)
轉換:
embedding_ds = ds.map(lambda image, label: generate_embedding(image, label, embedder), num_parallel_calls=tf.data.AUTOTUNE)
保存:
embedding_ds.save(path)
但我對embedding_ds
有疑問,它不是tf.data.Dataset
(我期望的),而是tf.raw_ops.ParallelMapDataset
,它沒有保存方法。 有人可以給個建議嗎?
看起來這個問題出現在我的 tensorflow 版本 (2.9.2) 而不是出現在 2.11
也許更新? 在 2.11.0 中,它有效:
import tensorflow as tf
ds = tf.data.Dataset.range(5)
tf.__version__ # 2.11.0
ds = ds.map(lambda e : (e + 3) % 5, num_parallel_calls=3)
ds.save('test') # works
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.