[英]TypeError: unsupported callable using Dataset with estimator input_fn
[英]How do I build an input_fn for Estimator using images stored in a TFRecords file
是否有一個如何為圖像分類模型構建input_fn
所需的tf.contrib.learn.Estimator
的示例? 我的圖像存儲在多個TFRecords文件中。
使用tf.contrib.learn.read_batch_record_features
,我能夠生成批量的編碼圖像字符串。 但是,我沒有看到將這些字符串轉換為圖像的簡單方法。
在這里你可以使用類似下面的內容來存儲在train.tfrecords
和test.tfrecords
中的mnist
和fashion-mnist
數據集。
轉換為tfrecords
是通過此處的代碼完成的,您需要使用解析器來獲取原始圖像和標簽。
def parser(serialized_example):
"""Parses a single tf.Example into image and label tensors."""
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
image.set_shape([28 * 28])
# Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
image = tf.cast(image, tf.float32) / 255 - 0.5
label = tf.cast(features['label'], tf.int32)
return image, label
在使用解析器后,其余部分很簡單,您只需要調用TFRecordDataset(train_filenames)
,然后將解析器函數映射到每個元素,這樣您就可以獲得圖像和標簽作為輸出。
# Keep list of filenames, so you can input directory of tfrecords easily
training_filenames = ["data/train.tfrecords"]
test_filenames = ["data/test.tfrecords"]
# Define the input function for training
def train_input_fn():
# Import MNIST data
dataset = tf.contrib.data.TFRecordDataset(train_filenames)
# Map the parser over dataset, and batch results by up to batch_size
dataset = dataset.map(parser, num_threads=1, output_buffer_size=batch_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.