簡體   English   中英

如何在 Tensorflow 2 中解碼示例(從 1.12 移植)

[英]How to decode examples in Tensorflow 2 (porting from 1.12)

我有以下方法應該從序列化的TFRecordDataset解碼樣本:

def decode_example(self, serialized_example):
    """Return a dict of Tensors from a serialized tensorflow.Example."""
    data_fields, data_items_to_decoders = self.example_reading_spec()
    # Necessary to rejoin examples in the correct order with the Cloud ML Engine
    # batch prediction API.
    data_fields['batch_prediction_key'] = tf.io.FixedLenFeature([1], tf.int64, 0)
    if data_items_to_decoders is None:
        data_items_to_decoders = {
            field: tf.contrib.slim.tfexample_decoder.Tensor(field)
            for field in data_fields
        }

    decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder(data_fields, data_items_to_decoders)

    decode_items = list(sorted(data_items_to_decoders))
    decoded = decoder.decode(serialized_example, items=decode_items)
    return dict(zip(decode_items, decoded))

但是,這在 Tensorflow 2 下不起作用。

tf.contrib不再存在,我找不到任何可以用來解碼這些示例的東西。

安裝tensorflow-data-validation后,我什至找不到TFExampleDecoder

知道那里出了什么問題和/或我如何解碼我的例子嗎?

我能夠使用tf.io.parse_single_example使其工作。

我們必須像往常一樣聲明我們的數據字段( example_reading_spec ),然后我們可以使用它來解碼一個示例:

def example_reading_spec():

    data_fields = {
        'inputs': tf.io.VarLenFeature(tf.float32),
        'targets': tf.io.VarLenFeature(tf.int64),
    }

    return data_fields

def decode_example(serialized_example):
    """Return a dict of Tensors from a serialized tensorflow.Example."""
    return tf.io.parse_single_example(
        serialized_example,
        features=example_reading_spec()
    )

現在我們可以使用Dataset.map像這樣加載我們的數據集分片:

record_dataset = tf.data.TFRecordDataset(filenames, buffer_size=1024)
record_dataset = record_dataset.map(decode_example)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM