簡體   English   中英

如何保存 ParallelMapDataset?

[英]How to save ParallelMapDataset?

我有一個輸入數據集(我們將其命名為 ds),一個 function 傳遞給編碼器(名為embedder的模型)。 我想制作一個編碼數據集並將其保存到文件中。 我試圖做什么:

轉換器 function:

def generate_embedding(image, label, embedder):
  return (embedder(image)[0], label)

轉換:

embedding_ds = ds.map(lambda image, label: generate_embedding(image, label, embedder), num_parallel_calls=tf.data.AUTOTUNE)

保存:

embedding_ds.save(path)

但我對embedding_ds有疑問,它不是tf.data.Dataset (我期望的),而是tf.raw_ops.ParallelMapDataset ,它沒有保存方法。 有人可以給個建議嗎?


看起來這個問題出現在我的 tensorflow 版本 (2.9.2) 而不是出現在 2.11

也許更新? 在 2.11.0 中,它有效:

import tensorflow as tf

ds = tf.data.Dataset.range(5)

tf.__version__ # 2.11.0

ds = ds.map(lambda e : (e + 3) % 5, num_parallel_calls=3)

ds.save('test') # works

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM