簡體   English   中英

如何使用存儲在TFRecords文件中的圖像為Estimator構建input_fn

[英]How do I build an input_fn for Estimator using images stored in a TFRecords file

是否有一個如何為圖像分類模型構建input_fn所需的tf.contrib.learn.Estimator的示例? 我的圖像存儲在多個TFRecords文件中。

使用tf.contrib.learn.read_batch_record_features ,我能夠生成批量的編碼圖像字符串。 但是,我沒有看到將這些字符串轉換為圖像的簡單方法。

在這里你可以使用類似下面的內容來存儲在train.tfrecordstest.tfrecords中的mnistfashion-mnist數據集。

轉換為tfrecords是通過此處的代碼完成的,您需要使用解析器來獲取原始圖像和標簽。

def parser(serialized_example):
  """Parses a single tf.Example into image and label tensors."""
  features = tf.parse_single_example(
      serialized_example,
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image.set_shape([28 * 28])

  # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
  image = tf.cast(image, tf.float32) / 255 - 0.5
  label = tf.cast(features['label'], tf.int32)
  return image, label

在使用解析器后,其余部分很簡單,您只需要調用TFRecordDataset(train_filenames) ,然后將解析器函數映射到每個元素,這樣您就可以獲得圖像和標簽作為輸出。

# Keep list of filenames, so you can input directory of tfrecords easily
training_filenames = ["data/train.tfrecords"]
test_filenames = ["data/test.tfrecords"]

# Define the input function for training
def train_input_fn():
  # Import MNIST data
  dataset = tf.contrib.data.TFRecordDataset(train_filenames)

  # Map the parser over dataset, and batch results by up to batch_size
  dataset = dataset.map(parser, num_threads=1, output_buffer_size=batch_size)
  dataset = dataset.batch(batch_size)
  dataset = dataset.repeat()
  iterator = dataset.make_one_shot_iterator()

  features, labels = iterator.get_next()

  return features, labels

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM