
How can I read images from a directory as input and target output while training a CNN model in TensorFlow?

I want to use a CNN for a deblurring task. My training data is a directory of PNG images plus a corresponding text file that lists the file names.

The data is too large to load into memory in one step. Is there an API or some other method that lets me read each blurry image as the input and its ground truth as the expected result during training?

I have spent quite a lot of time on this, but I got confused after reading the online API documentation.

The approach is not that complicated. TensorFlow provides the TFRecords file format, so the dataset can be read from disk a few examples at a time instead of being loaded into memory all at once.
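The idea behind this is streaming: only the current batch has to be in memory at any moment. As a minimal sketch of that idea in plain Python, independent of TensorFlow, the generators below walk a list of file names lazily and group the resulting path pairs into batches. The names `iter_pairs` and `iter_batches`, and the assumption that the blurred and sharp images share a file name in two separate directories, are hypothetical and not taken from the question.

```python
from itertools import islice


def iter_pairs(lines, blur_dir, sharp_dir):
    """Yield (blurred_path, sharp_path) tuples one at a time.

    `lines` is any iterable of file names, for example an open text file.
    Assumption: both directories use the same file name for a pair.
    """
    for line in lines:
        name = line.strip()
        if name:  # skip blank lines
            yield "{}/{}".format(blur_dir, name), "{}/{}".format(sharp_dir, name)


def iter_batches(pairs, batch_size):
    """Group an iterable of pairs into lists of at most batch_size items."""
    pairs = iter(pairs)
    while True:
        batch = list(islice(pairs, batch_size))
        if not batch:
            return
        yield batch
```

Passing an open file as `lines` keeps memory use flat, because a file object yields one line per iteration. The TFRecords code that follows applies the same streaming idea with TensorFlow's own reader.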

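import tensorflow as tf  # TensorFlow 1.x API
from PIL import Image

# cwd, get_file_name, IMAGE_HEIGH and IMAGE_WIDTH are assumed to be defined elsewhere.

def create_cord():
    # Write every blurred/original image pair into a single TFRecords file.
    writer = tf.python_io.TFRecordWriter("train.tfrecords")
    for index in range(66742):
        blur_file_name = get_file_name(index, True)
        orig_file_name = get_file_name(index, False)
        blur_image_path = cwd + blur_file_name
        orig_image_path = cwd + orig_file_name

        blur_image = Image.open(blur_image_path)
        orig_image = Image.open(orig_image_path)

        # PIL's resize takes (width, height), so this produces IMAGE_WIDTH rows of
        # IMAGE_HEIGH pixels, which is the layout read_and_decode reshapes to.
        blur_image = blur_image.resize((IMAGE_HEIGH, IMAGE_WIDTH))
        orig_image = orig_image.resize((IMAGE_HEIGH, IMAGE_WIDTH))

        # Store the raw pixel bytes; shape information is not saved.
        blur_image_raw = blur_image.tobytes()
        orig_image_raw = orig_image.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            "blur_image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[blur_image_raw])),
            "orig_image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[orig_image_raw]))
        }))
        # Write inside the loop so every pair is stored, not only the last one.
        writer.write(example.SerializeToString())
    writer.close()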
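One thing to watch: `tobytes()` stores only pixel values, so the reader has to know the shape and channel count in advance. The reading code reshapes to three channels, which holds for RGB images, but a PNG with an alpha channel (mode RGBA) or a grayscale PNG (mode L) produces a different byte count and the reshape fails. A minimal sketch of a sanity check that could be run before writing each record is shown here. The helper names are hypothetical, and the table assumes 8 bits per channel.

```python
# Bytes per pixel for common 8-bit PIL image modes.
CHANNELS = {"L": 1, "RGB": 3, "RGBA": 4}


def expected_raw_length(width, height, mode="RGB"):
    """Byte count of the raw pixel data for an 8-bit image of this size and mode."""
    return width * height * CHANNELS[mode]


def check_raw_length(raw, width, height, mode="RGB"):
    """Return raw unchanged, or raise ValueError if its length does not fit the shape."""
    expected = expected_raw_length(width, height, mode)
    if len(raw) != expected:
        raise ValueError(
            "got {} bytes, expected {} for {}x{} {}".format(
                len(raw), expected, width, height, mode))
    return raw
```

If the check fails because of an alpha channel, Pillow's `Image.convert("RGB")` can be applied before `tobytes()` to force three channels.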

To read the dataset back:

def read_and_decode(filename):
    # Queue of input files; the reader pulls one serialized example at a time.
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'blur_image_raw': tf.FixedLenFeature([], tf.string),
                                           'orig_image_raw': tf.FixedLenFeature([], tf.string),
                                       })

    # Decode the raw bytes, restore the image shape, and scale to [-0.5, 0.5].
    blur_img = tf.decode_raw(features['blur_image_raw'], tf.uint8)
    blur_img = tf.reshape(blur_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3])
    blur_img = tf.cast(blur_img, tf.float32) * (1. / 255) - 0.5

    # The ground truth must come from 'orig_image_raw', not from the blurred image.
    orig_img = tf.decode_raw(features['orig_image_raw'], tf.uint8)
    orig_img = tf.reshape(orig_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3])
    orig_img = tf.cast(orig_img, tf.float32) * (1. / 255) - 0.5

    return blur_img, orig_img


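if __name__ == '__main__':

    # create_cord()  # run once to build train.tfrecords

    blur, orig = read_and_decode("train.tfrecords")
    blur_batch, orig_batch = tf.train.shuffle_batch([blur, orig],
                                                    batch_size=3, capacity=1000,
                                                    min_after_dequeue=100)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # Start the queue runners that feed the input pipeline.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(3):
            v, l = sess.run([blur_batch, orig_batch])
            print(v.shape, l.shape)
        # Stop the queue threads cleanly before the session closes.
        coord.request_stop()
        coord.join(threads)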
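The scaling `* (1. / 255) - 0.5` in `read_and_decode` maps 8-bit pixel values from [0, 255] to [-0.5, 0.5]. A network trained on targets in that range also produces outputs in roughly that range, so they have to be mapped back before a deblurred result can be saved as an image. A minimal sketch of the forward and inverse mapping on single values, in plain Python, follows. The function names are hypothetical, and the clipping step is an assumption about how out-of-range outputs should be handled.

```python
def normalize(pixel):
    """Map an 8-bit value in [0, 255] to [-0.5, 0.5], as read_and_decode does."""
    return pixel * (1.0 / 255) - 0.5


def denormalize(value):
    """Map a value in [-0.5, 0.5] back to an 8-bit integer.

    Values outside the range are clipped to 0 or 255.
    """
    pixel = int(round((value + 0.5) * 255))
    return max(0, min(255, pixel))
```

Applied to a whole output array, the same arithmetic is an add, a multiply, a round, and a clip before casting to `uint8`.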


 