
Unable to read from Tensorflow tfrecord file

I am able to create a tfrecords file with the following code.

import numpy as np
import tensorflow as tf

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def convert_to_tfrecord(images, labels, file_name):
    # images is a numpy array of shape (num_images, channel, rows, column)
    # labels is a numpy array of shape (num_images,)
    (num_images, depth, rows, cols) = np.shape(images)
    writer = tf.python_io.TFRecordWriter(file_name)
    for index in range(num_images):
        # Serialize each image as raw float32 bytes; the dimensions are stored as separate features.
        image_raw = images[index].astype(np.float32).tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'depth': _int64_feature(depth),
            'label': _int64_feature(int(labels[index])),
            'image_raw': _bytes_feature(image_raw)}))

        writer.write(example.SerializeToString())
    writer.close()

However, when I read data back from the tfrecord file with the following functions

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(
       serialized_example,
       features={
          'height': tf.FixedLenFeature([], tf.int64),
          'width': tf.FixedLenFeature([], tf.int64),
          'depth': tf.FixedLenFeature([], tf.int64),
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })

    image = tf.decode_raw(img_features['image_raw'], tf.float32)
    label = tf.cast(img_features['label'], tf.int32)
    height = tf.cast(img_features['height'], tf.int32)
    width = tf.cast(img_features['width'], tf.int32)
    depth = tf.cast(img_features['depth'], tf.int32)
    image_shape = tf.stack([depth,height, width])
    image = tf.reshape(image, image_shape)
    return image,label

def inputs(batch_size, num_epochs):
    # dir_path is a global variable
    file_path = dir_path + 'set1.tfrecords'
    # Pass num_epochs through instead of hard-coding a single epoch
    filename_queue = tf.train.string_input_producer([file_path], num_epochs=num_epochs)
    image, label = read_and_decode(filename_queue)
    images, sparse_labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=2,
        capacity=1000 + 3 * batch_size, min_after_dequeue=1000)
    return images, sparse_labels

I keep getting the following error

 images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, num_threads=1, min_after_dequeue=10) 

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/input.py", line 1225, in shuffle_batch
name=name)

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/input.py", line 781, in _shuffle_batch
dtypes=types, shapes=shapes, shared_name=shared_name)

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/data_flow_ops.py", line 641, in __init__
shapes = _as_shape_list(shapes, dtypes)

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/data_flow_ops.py", line 77, in _as_shape_list
raise ValueError("All shapes must be fully defined: %s" % shapes)

ValueError: All shapes must be fully defined: [TensorShape([Dimension(None)]), TensorShape([])]

What is causing this error, and how can I fix it? I am able to read the tfrecords file by iterating over it with tf.python_io.tf_record_iterator(path=filename).

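The error occurs because tf.train.shuffle_batch needs to know the shape of each tensor in order to batch it (every item in a batch must have the same shape). Raw data can, in principle, have any size, so tf.decode_raw cannot infer a static shape for the tensor it returns: it is a 1-D tensor of unknown length, which is the TensorShape([Dimension(None)]) shown in the error message.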
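As a rough illustration of why the decoded tensor has no shape, here is a NumPy-only sketch of the same write-then-decode round trip. The (2, 3, 4) shape is an arbitrary illustrative value, not one from the question; the point is only that the byte string keeps the element count and loses the dimensions.

```python
import numpy as np

# Hypothetical small image; the shape is chosen only for illustration.
image = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)

# Serializing keeps the values but discards the shape.
raw = image.tobytes()

# Decoding raw bytes can only give back a flat vector.
flat = np.frombuffer(raw, dtype=np.float32)

# The shape has to be supplied again from outside knowledge.
restored = flat.reshape(2, 3, 4)

print(len(raw), flat.shape, restored.shape)
```

tf.decode_raw is in the same position at graph-construction time, except that it does not even know the length of the flat vector until a record is actually read.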

In the comments you mentioned that all images have shape (192, 81, 2), so you only need to set that shape on the image tensor before returning from read_and_decode:

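def read_and_decode(filename_queue):
    # rest of your code here
    # Use Python integers here: a shape built from the height/width/depth
    # tensors is only known at run time and cannot be passed to set_shape.
    image_shape = [192, 81, 2]
    image = tf.reshape(image, image_shape)
    image.set_shape(image_shape)  # makes the static shape explicit for shuffle_batch
    return image, label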
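One way to check the effect of a static shape before wiring it into the queue pipeline is the sketch below. It assumes a TensorFlow version that provides tf.compat.v1 and tf.io.decode_raw (1.13 or later, including 2.x); the placeholder stands in for img_features['image_raw'], and IMAGE_SHAPE is a name introduced here for the (192, 81, 2) shape mentioned above.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
IMAGE_SHAPE = (192, 81, 2)  # shape stated in the comments; adjust to your data

with tf.Graph().as_default():
    # Stand-in for the serialized image bytes read from the record.
    raw = tf1.placeholder(tf.string, shape=[])

    # decode_raw yields a 1-D tensor whose length is unknown statically.
    flat = tf.io.decode_raw(raw, tf.float32)
    flat_defined = flat.shape.is_fully_defined()

    # Reshaping with Python integers fixes the static shape.
    shaped = tf.reshape(flat, IMAGE_SHAPE)
    shaped_defined = shaped.shape.is_fully_defined()

    with tf1.Session() as sess:
        payload = np.zeros(IMAGE_SHAPE, dtype=np.float32).tobytes()
        decoded = sess.run(shaped, feed_dict={raw: payload})

print(flat_defined, shaped_defined, decoded.shape)
```

If the second flag is True, tf.train.shuffle_batch should accept the tensor, since every shape it receives is then fully defined.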

