簡體   English   中英

Tensorflow 無法解碼 tfrecords 中的 jpeg 字節

[英]Tensorflow failed to decode jpeg bytes in tfrecords

我試圖將一些圖像寫入 tfrecord 文件,但我發現它太大了。 然后我嘗試將原始 jpeg 字節寫入 tfrecord 文件。 但是當我嘗試閱讀它時,出現異常:ValueError:Shape must be rank 0 but is rank 1 for 'DecodeJpeg' (op: 'DecodeJpeg') with input shapes: [32].

以下是我的代碼

import tensorflow as tf
import os


def write_features(example_image_paths, tf_records_path):
    with tf.python_io.TFRecordWriter(tf_records_path) as writer:
        for image_path in example_image_paths:
            with open(image_path, 'rb') as f:
                image_bytes = f.read()
            feautres = tf.train.Features(
                feautres={
                    'images':
                    tf.train.Feature(bytes_list=tf.train.BytesList(
                        value=image_bytes))
                })
            example = tf.train.Example(feautres)
            writer.write(example.SerializeToString())


def extract_features_batch(serialized_batch):
    """

    :param serialized_batch:
    :return:
    """
    features = tf.parse_example(
        serialized_batch,
        features={'images': tf.FixedLenFeature([], tf.string)})
    bs = features['images'].shape[0]
    images = tf.image.decode_image(features['images'], 3)
    w, h = (280, 32)
    images = tf.cast(x=images, dtype=tf.float32)
    images = tf.reshape(images, [bs, h, w, 3])

    return images


def inputs(tfrecords_path, batch_size, num_epochs, num_threads=4):
    """

    :param tfrecords_path:
    :param batch_size:
    :param num_epochs:
    :param num_threads:
    :return: input_images, input_labels, input_image_names
    """

    if not num_epochs:
        num_epochs = None

    dataset = tf.data.TFRecordDataset(tfrecords_path)

    dataset = dataset.batch(batch_size, drop_remainder=True)

    # The map transformation takes a function and applies it to every element
    # of the dataset.
    dataset = dataset.map(map_func=extract_features_batch,
                          num_parallel_calls=num_threads)
    dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.repeat()

    iterator = dataset.make_one_shot_iterator()

    return iterator.get_next(name='IteratorGetNext')


if __name__ == '__main__':
    pass
    # img_names = os.listdir('./images')
    # img_paths = []
    # for img_name in img_paths:
    #     img_paths.append(os.path.join('./images', img_name))
    # write_features(img_paths, 'test.tfrecords')

    images = inputs('./test.tfrecords', 32, None)

如何正確讀取和解碼 jpeg 字節? 謝謝!

您需要在批處理數據集之前解碼圖像。 換句話說,在您的 inputs() function 中,“正確”的順序是:

dataset = dataset.map(map_func=extract_features_batch,
                      num_parallel_calls=num_threads) 

dataset = dataset.batch(batch_size, drop_remainder=True)

The documentation says ( https://www.tensorflow.org/api_docs/python/tf/io/decode_image ) that tf.io.decode_image expects an image in a form of a scalar or 0-dimensional string (0-D string is被認為是一個標量),而如果您首先對數據集 object 進行批處理,則 tf.io.decode_image 接收圖像列表(或批次)(表示為 batch_size 乘以 0 維字符串的列表)。 然后它抱怨它在收到一個形狀為 [32] 的數組時期望 0 維數組(在您的情況下是批量大小)。

我不知道我們如何優化批處理的輸入管道,而不是在處理后低效地進行批處理。 像往常一樣,在 tf 2.0 的文檔中沒有任何內容。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM