簡體   English   中英

如何讀寫二維數組的tfrecord文件

[英]How to read and write tfrecord files of 2d array

我想將一個大小為 (n, 3) 的二維數組制作成一個tfrecord file ,然后讀取它。

我編寫的制作tfrecord file的代碼是

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
example = tf.train.Example(
      features=tf.train.Features(
          feature={
              'arry_x':_float_feature(array[:,0]),
              'arry_y':_float_feature(array[:,1]),
              'arry_z':_float_feature(array[:,2])}
         )
      )

with tf.compat.v1.python_io.TFRecordWriter(file_name) as writer:
    writer.write(example.SerializeToString())

我試圖用TFRecordReader讀取文件

def get_tfrecord_feature():
    return{
        'arry_x': tf.compat.v1.io.FixedLenFeature([], tf.float32),
        'arry_y': tf.compat.v1.io.FixedLenFeature([], tf.float32),
        'arry_z': tf.compat.v1.io.FixedLenFeature([], tf.float32)
    }
filenames = [file_name, file_name2, ...]
file_name_queue = tf.train.string_input_producer(filenames)

reader = tf.TFRecordReader()
_, serialized_example = reader.read(file_name_queue)

data = tf.compat.v1.io.parse_single_example(serialized_example, features=get_tfrecord_feature())

x = data['arry_x']
y = data['arry_y']
z = data['arry_z']

x, y, z = tf.train.batch([x, y, z], batch_size=1)

我使用 tf.Session 來檢查代碼

with tf.compat.v1.Session() as sess:
    print(sess.run(x))

代碼運行沒有錯誤,但 session 不打印任何值。 我認為讀取tfrecord file的方式是錯誤的。 有人可以幫我嗎?

我認為您應該在解析 tf 記錄時將列表長度(在您的情況下為array.shape[0]如下)添加到特征定義中。

def get_tfrecord_feature():
    return{
        'arry_x': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
        'arry_y': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
        'arry_z': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32)
    }

如果 FixedLenFeature 只有一個元素,您可以將形狀保留為 []。 https://tensorflow.org/versions/r1.15/api_docs/python/tf/io/FixedLenFeature

感謝donglinjy的建議,我在這里修復了我的代碼

def get_tfrecord_feature():
    return{
        'arry_x': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
        'arry_y': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
        'arry_z': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32)
    }

和這里。

with tf.compat.v1.Session() as sess:
    coord=tf.train.Coordinator()
    threads=tf.train.start_queue_runners(coord=coord)
    print(sess.run(x))

現在可以了。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM