繁体   English   中英

从TFRecordDataset获取数据集为numpy数组

[英]Get data set as numpy array from TFRecordDataset

我正在使用新的tf.data API为CIFAR10数据集创建一个迭代器。 我正在从两个.tfrecord文件中读取数据。 一个保存训练数据(train.tfrecords),另一个保存测试数据(test.tfrecords)。 这很好用。 但是,在某些时候,我需要数据集(训练数据和测试数据)作为numpy数组

是否可以从tf.data.TFRecordDataset对象中检索数据集为numpy数组?

您可以使用tf.data.Dataset.batch()转换和tf.contrib.data.get_single_element()来执行此操作。 作为复习, dataset.batch(n)将占用dataset n个连续元素,并通过连接每个组件将它们转换为一个元素。 这要求所有元素每个组件具有固定的形状。 如果n大于dataset的元素数(或者如果n不精确地划分元素数),则最后一批可以更小。 因此,您可以为n选择较大的值并执行以下操作:

import numpy as np
import tensorflow as tf

# Insert your own code for building `dataset`. For example:
dataset = tf.data.TFRecordDataset(...)  # A dataset of tf.string records.
dataset = dataset.map(...)  # Extract components from each tf.string record.

# Choose a value of `max_elems` that is at least as large as the dataset.
max_elems = np.iinfo(np.int64).max
dataset = dataset.batch(max_elems)

# Extracts the single element of a dataset as one or more `tf.Tensor` objects.
# No iterator needed in this case!
whole_dataset_tensors = tf.contrib.data.get_single_element(dataset)

# Create a session and evaluate `whole_dataset_tensors` to get arrays.
with tf.Session() as sess:
    whole_dataset_arrays = sess.run(whole_dataset_tensors)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM