簡體   English   中英

Tensorflow CNN圖像增強管道

[英]Tensorflow CNN image augmentation pipeline

我正在嘗試學習新的Tensorflow API,但對於在哪里獲取輸入批處理張量的句柄,我有點迷失了,因此我可以使用tf.image等操作和擴充它們。

這是我當前的網絡和管道:

trainX, testX, trainY, testY = read_data()
# trainX [num_image, height, width, channels], these are numpy arrays

#...
train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))
test_dataset = tf.data.Dataset.from_tensor_slices((testX, testY))

#...
iterator = tf.data.Iterator.from_structure(train_dataset.output_types, 
                 train_dataset.output_shapes)
features, labels = iterator.get_next()
train_init_op = iterator.make_initializer(train_dataset)
test_init_op = iterator.make_initializer(test_dataset)

#...defining cnn architecture...

# In the train loop
TrainLoop {
   sess.run(train_init_op)  # switching to train data
   sess.run(train_step, ...) # running a train step

   #... 
   sess.run(test_init_op)  # switching to test data
   test_loss = sess.run(loss, ...) # printing test loss after epoch
}

我正在使用Dataset API創建2個數據集,以便在trainloop中可以計算火車並測試損失並記錄它們。

我將在此管道中的哪個位置處理和扭曲輸入的圖像? 我沒有為trainX輸入批次創建任何tf.placeholders,因此我無法使用tf.image來操縱它們,因為例如tf.image.flip_up_down需要3D或4D張量。

  • 用新API實施此管道的自然方法是什么?
  • 有沒有一種模塊或簡單的方法來增加輸入的圖像批處理量,以適合該管道?

最近發布了一篇非常不錯的文章演講 ,比我在這里的回復更詳細地介紹了API。 這是一個簡單的示例:

import tensorflow as tf
import numpy as np


def read_data():
    n_train = 100
    n_test = 50
    height = 20
    width = 30
    channels = 3
    trainX = (np.random.random(
        size=(n_train, height, width, channels)) * 255).astype(np.uint8)
    testX = (np.random.random(
            size=(n_test, height, width, channels))*255).astype(np.uint8)
    trainY = (np.random.random(size=(n_train,))*10).astype(np.int32)
    testY = (np.random.random(size=(n_test,))*10).astype(np.int32)
    return trainX, testX, trainY, testY


trainX, testX, trainY, testY = read_data()
# trainX [num_image, height, width, channels], these are numpy arrays


train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))
test_dataset = tf.data.Dataset.from_tensor_slices((testX, testY))


def map_single(x, y):
    print('Map single:')
    print('x shape: %s' % str(x.shape))
    print('y shape: %s' % str(y.shape))
    x = tf.image.per_image_standardization(x)
    # Consider: x = tf.image.random_flip_left_right(x)
    return x, y


def map_batch(x, y):
    print('Map batch:')
    print('x shape: %s' % str(x.shape))
    print('y shape: %s' % str(y.shape))
    # Note: this flips ALL images left to right. Not sure this is what you want
    # UPDATE: looks like tf documentation is wrong and you need a 3D tensor?
    # return tf.image.flip_left_right(x), y
    return x, y


batch_size = 32
train_dataset = train_dataset.repeat().shuffle(100)
train_dataset = train_dataset.map(map_single, num_parallel_calls=8)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.map(map_batch)
train_dataset = train_dataset.prefetch(2)

test_dataset = test_dataset.map(
        map_single, num_parallel_calls=8).batch(batch_size).map(map_batch)
test_dataset = test_dataset.prefetch(2)


iterator = tf.data.Iterator.from_structure(train_dataset.output_types, 
                 train_dataset.output_shapes)
features, labels = iterator.get_next()
train_init_op = iterator.make_initializer(train_dataset)
test_init_op = iterator.make_initializer(test_dataset)


with tf.Session() as sess:
    sess.run(train_init_op)
    feat, lab = sess.run((features, labels))

    print(feat.shape)
    print(lab.shape)

    sess.run(test_init_op)
    feat, lab = sess.run((features, labels))

    print(feat.shape)
    print(lab.shape)    

一些注意事項:

  1. 這種方法依賴於能夠將整個數據集加載到內存中。 如果不能,請考慮使用tf.data.Dataset.from_generator 如果您的混洗緩沖區很大,這可能導致緩慢的混洗時間。 我的首選方法是將某些keys張量完全加載到內存中-它可能只是每個示例的索引-然后使用tf.py_func將鍵值map到數據值。 這比轉換為tfrecords效率tfrecords ,但是通過prefetching它可能不會影響性能。 由於改組是在映射之前完成的,因此您只需要將shuffle_buffer鍵加載到內存中,而不是shuffle_buffer示例。
  2. 要擴充數據集,請在批處理操作之前或之后使用tf.data.Dataset.map ,這取決於您是否要應用批量操作(在4D圖像張量上起作用)或元素級操作( 3D圖像張量)。 請注意, tf.image.flip_left_right的文檔tf.image.flip_left_right已過時,因為在嘗試使用4D張量時出現錯誤。 如果要隨機擴充數據,請使用tf.image.random_flip_left_right而不是tf.image.flip_left_right
  3. 如果您使用的是tf.estimator.Estimator (或不介意將代碼轉換為使用它),請查看tf.estimator.train_and_evaluate ,以了解tf.estimator.train_and_evaluate在數據集之間進行切換的內置方法。
  4. 考慮洗牌/與重復數據集shuffle / repeat的方法。 有關效率的說明,請參見該文章 特別是,對於大多數應用程序,重復->隨機播放->映射->批處理->逐批映射->預取似乎是最佳的操作順序。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM