繁体   English   中英

如何将 Tensorflow 数据集保存到文件中?

[英]How do you save a Tensorflow dataset to a file?

SO上至少还有两个这样的问题,但没有一个得到回答。

我有一个形式的数据集:

<TensorSliceDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>

和另一种形式:

<BatchDataset shapes: ((None, 512), (None, 512), (None, 512), (None,)), types: (tf.int32, tf.int32, tf.int32, tf.int32)>

我看了又看,但找不到将这些数据集保存到以后可以加载的文件的代码。 我得到的最接近的是 TensorFlow 文档中的这个页面,它建议使用tf.data.experimental.TFRecordWriter序列化张量,然后使用tf.io.serialize_tensor将它们写入文件。

但是,当我使用代码尝试此操作时:

dataset.map(tf.io.serialize_tensor)
writer = tf.data.experimental.TFRecordWriter('mydata.tfrecord')
writer.write(dataset)

我在第一行得到一个错误:

TypeError:serialize_tensor() 从 1 到 2 个位置 arguments 但给出了 4 个

如何修改上述内容(或做其他事情)以实现我的目标?

GitHUb 上出现了一个事件,并且似乎 TF 2.3 中有一个新功能可用于写入磁盘:

https://www.tensorflow.org/api_docs/python/tf/data/experimental/save https://www.tensorflow.org/api_docs/python/tf/data/experimental/load

我还没有测试过这个功能,但它似乎正在做你想做的事。

TFRecordWriter似乎是最方便的选择,但不幸的是,它只能编写每个元素只有一个张量的数据集。 您可以使用以下几种解决方法。 首先,由于所有张量都具有相同的类型和相似的形状,因此您可以将它们全部连接成一个,然后在加载时将它们拆分回来:

import tensorflow as tf

# Write
a = tf.zeros((100, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))
print(ds)
# <TensorSliceDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>
def write_map_fn(x1, x2, x3, x4):
    return tf.io.serialize_tensor(tf.concat([x1, x2, x3, tf.expand_dims(x4, -1)], -1))
ds = ds.map(write_map_fn)
writer = tf.data.experimental.TFRecordWriter('mydata.tfrecord')
writer.write(ds)

# Read
def read_map_fn(x):
    xp = tf.io.parse_tensor(x, tf.int32)
    # Optionally set shape
    xp.set_shape([1537])  # Do `xp.set_shape([None, 1537])` if using batches
    # Use `x[:, :512], ...` if using batches
    return xp[:512], xp[512:1024], xp[1024:1536], xp[-1]
ds = tf.data.TFRecordDataset('mydata.tfrecord').map(read_map_fn)
print(ds)
# <MapDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>

但是,更一般地说,您可以简单地为每个张量创建一个单独的文件,然后将它们全部读取:

import tensorflow as tf

# Write
a = tf.zeros((100, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))
for i, _ in enumerate(ds.element_spec):
    ds_i = ds.map(lambda *args: args[i]).map(tf.io.serialize_tensor)
    writer = tf.data.experimental.TFRecordWriter(f'mydata.{i}.tfrecord')
    writer.write(ds_i)

# Read
NUM_PARTS = 4
parts = []
def read_map_fn(x):
    return tf.io.parse_tensor(x, tf.int32)
for i in range(NUM_PARTS):
    parts.append(tf.data.TFRecordDataset(f'mydata.{i}.tfrecord').map(read_map_fn))
ds = tf.data.Dataset.zip(tuple(parts))
print(ds)
# <ZipDataset shapes: (<unknown>, <unknown>, <unknown>, <unknown>), types: (tf.int32, tf.int32, tf.int32, tf.int32)>

可以将整个数据集放在一个文件中,每个元素有多个单独的张量,即作为包含tf.train.Example的 TFRecords 文件,但我不知道是否有办法在 TensorFlow 中创建它们,也就是说,无需将数据从数据集中取出到 Python 中,然后将其写入记录文件。

要添加 Yoan 的答案:

tf.experimental.save() 和 load() API 运行良好。 您还需要手动将 ds.element_spec 保存到磁盘,以便稍后/在不同的上下文中加载()。

酸洗对我很有效:

1- 保存:

tf.data.experimental.save(
    ds, tf_data_path, compression='GZIP'
)
with open(tf_data_path + '/element_spec', 'wb') as out_:  # also save the element_spec to disk for future loading
    pickle.dump(ds.element_spec, out_)

2-对于加载,您需要包含 tf 分片的文件夹路径和我们手动腌制的 element_spec

with open(tf_data_path + '/element_spec', 'rb') as in_:
    es = pickle.load(in_)

loaded = tf.data.experimental.load(
    tf_data_path, es, compression='GZIP'
)

我也一直在研究这个问题,到目前为止,我已经编写了以下实用程序(也可以在我的仓库中找到)

def cache_with_tf_record(filename: Union[str, pathlib.Path]) -> Callable[[tf.data.Dataset], tf.data.TFRecordDataset]:
    """
    Similar to tf.data.Dataset.cache but writes a tf record file instead. Compared to base .cache method, it also insures that the whole
    dataset is cached
    """

    def _cache(dataset):
        if not isinstance(dataset.element_spec, dict):
            raise ValueError(f"dataset.element_spec should be a dict but is {type(dataset.element_spec)} instead")
        Path(filename).parent.mkdir(parents=True, exist_ok=True)
        with tf.io.TFRecordWriter(str(filename)) as writer:
            for sample in dataset.map(transform(**{name: tf.io.serialize_tensor for name in dataset.element_spec.keys()})):
                writer.write(
                    tf.train.Example(
                        features=tf.train.Features(
                            feature={
                                key: tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()]))
                                for key, value in sample.items()
                            }
                        )
                    ).SerializeToString()
                )
        return (
            tf.data.TFRecordDataset(str(filename), num_parallel_reads=tf.data.experimental.AUTOTUNE)
            .map(
                partial(
                    tf.io.parse_single_example,
                    features={name: tf.io.FixedLenFeature((), tf.string) for name in dataset.element_spec.keys()},
                ),
                num_parallel_calls=tf.data.experimental.AUTOTUNE,
            )
            .map(
                transform(
                    **{name: partial(tf.io.parse_tensor, out_type=spec.dtype) for name, spec in dataset.element_spec.items()}
                )
            )
            .map(
                transform(**{name: partial(tf.ensure_shape, shape=spec.shape) for name, spec in dataset.element_spec.items()})
            )
        )

    return _cache

有了这个工具,我可以做到:

dataset.apply(cache_with_tf_record("filename")).map(...)

并且还直接加载数据集以供以后仅使用 util 的第二部分使用。

我仍在研究它,因此它可能会在以后更改,特别是使用正确的类型而不是所有字节进行序列化以节省空间(我猜)。

您可以像这样使用tf.data.experimental.savetf.data.experimental.load

保存它的代码:

tf_dataset = get_dataset()    # returns a tf.data.Dataset() file
tf.data.experimental.save(dataset=tf_dataset, path="path/to/desired/save/file_name")
with open("path/to/desired/save/file_name" + ".pickle")), 'wb') as file:
    pickle.dump(tf_dataset.element_spec, file)   # I need this for opening it later

打开代码:

element_spec = pickle.load("path/to/desired/save/file_name" + ".pickle", 'rb'))
tensor_data = tf.data.experimental.load("path/to/desired/save/file_name", element_spec=element_spec)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM