簡體   English   中英

如何使用DataSet API在Tensorflow中為tf.train.SequenceExample數據創建填充批次?

[英]How do I create padded batches in Tensorflow for tf.train.SequenceExample data using the DataSet API?

為了在Tensorflow中訓練LSTM模型 ,我將數據結構化為tf.train.SequenceExample格式並將其存儲到TFRecord文件中 我現在想使用新的DataSet API來生成用於訓練的填充批次 文檔中有一個使用padded_batch的例子,但對於我的數據,我無法弄清楚padded_shapes應該是什么值。

為了將TFrecord文件讀入批處理,我編寫了以下Python代碼:

import math
import tensorflow as tf
import numpy as np
import struct
import sys
import array

if(len(sys.argv) != 2):
  print "Usage: createbatches.py [RFRecord file]"
  sys.exit(0)


vectorSize = 40
inFile = sys.argv[1]

def parse_function_dataset(example_proto):
  sequence_features = {
      'inputs': tf.FixedLenSequenceFeature(shape=[vectorSize],
                                           dtype=tf.float32),
      'labels': tf.FixedLenSequenceFeature(shape=[],
                                           dtype=tf.int64)}

  _, sequence = tf.parse_single_sequence_example(example_proto, sequence_features=sequence_features)

  length = tf.shape(sequence['inputs'])[0]
  return sequence['inputs'], sequence['labels']

sess = tf.InteractiveSession()

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parse_function_dataset)
# dataset = dataset.batch(1)
dataset = dataset.padded_batch(4, padded_shapes=[None])
iterator = dataset.make_initializable_iterator()

batch = iterator.get_next()

# Initialize `iterator` with training data.
training_filenames = [inFile]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

print(sess.run(batch))

如果我使用dataset = dataset.batch(1) (在這種情況下不需要填充),代碼效果很好,但是當我使用padded_batch變體時,我收到以下錯誤:

TypeError:如果淺層結構是序列,則輸入也必須是序列。 輸入有類型:。

你能幫我弄清楚我應該為padded_shapes參數傳遞什么嗎?

(我知道有很多使用線程和隊列的示例代碼,但我寧願在這個項目中使用新的DataSet API)

你需要傳遞一個形狀元組。 在你的情況下你應該通過

dataset = dataset.padded_batch(4, padded_shapes=([vectorSize],[None]))

或嘗試

dataset = dataset.padded_batch(4, padded_shapes=([None],[None]))

檢查此代碼以獲取更多詳細信息 我不得不調試這個方法來弄清楚它為什么不適合我。

如果當前的Dataset對象包含元組,則還可以指定每個填充元素的形狀。

例如,我有一個(same_sized_images, Labels)數據集,每個標簽的長度不同但排名相同。

def process_label(resized_img, label):
    # Perfrom some tensor transformations
    # ......

    return resized_img, label

dataset = dataset.map(process_label)
dataset = dataset.padded_batch(batch_size, 
                               padded_shapes=([None, None, 3], 
                                              [None, None]))  # my label has rank 2

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM