簡體   English   中英

根據類別對部分訓練集應用不同的數據增強

[英]Apply different data augmentation to part of the train set based on the category

我正在研究機器學習過程來對圖像進行分類。 我的問題是我的數據集是不平衡的,在我的 5 類圖像中,我有大約 400 張圖像在一個類別中,而每個其他類別大約有 20 個圖像。

我想通過僅將數據增強應用於我的訓練集的某些類別來平衡我的訓練集。

這是我用於創建訓練驗證集的代碼:

# Import data
data_dir = pathlib.Path(r"C:\Train set")

# Define train and validation sets (80% - 20%)
batch_size = 32
img_height = 240
img_width = 240

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

這是我應用數據增強的方法,盡管這將適用於整個訓練集:

# Apply data augmentation
data_augmentation = keras.Sequential(
  [
    layers.experimental.preprocessing.RandomFlip("horizontal", 
                                                 input_shape=(img_height, 
                                                              img_width,
                                                              3)),
    layers.experimental.preprocessing.RandomRotation(0.1),
    layers.experimental.preprocessing.RandomZoom(0.1),
  ]
)

有什么辦法可以進入我的訓練集,提取那些圖像較少的類別,然后只對它們應用數據增強?

提前致謝!

我建議不要使用ImageDataGenerator而是使用定制的tf.data.Dataset 在映射操作中,您可以區別對待類別,例如:

def preprocess(filepath):
    category = tf.strings.split(filepath, os.sep)[0]
    read_file = tf.io.read_file(filepath)
    decode = tf.image.decode_jpeg(read_file, channels=3)
    resize = tf.image.resize(decode, (200, 200))
    image = tf.expand_dims(resize, 0)
    if tf.equal(category, 'tf_astronauts'):
        image = tf.image.flip_up_down(image)
        image = tf.image.flip_left_right(image)
    # image = tf.image.convert_image_dtype(image, tf.float32)
    # category = tf.cast(tf.equal(category, 'tf_astronauts'), tf.int32)
    return image, category

讓我演示一下。 讓我們為您創建一個包含訓練圖像的文件夾:

import tensorflow as tf
import matplotlib.pyplot as plt
import cv2
from skimage import data
from glob2 import glob
import os

cat = data.chelsea()
astronaut = data.astronaut()

for category, picture in zip(['tf_cats', 'tf_astronauts'], [cat, astronaut]):
    os.makedirs(category, exist_ok=True)
    for i in range(5):
        cv2.imwrite(os.path.join(category, category + f'_{i}.jpg'),
                    cv2.cvtColor(picture, cv2.COLOR_RGB2BGR))

files = glob('tf_*\\*.jpg')

現在你有這些文件:

['tf_astronauts\\tf_astronauts_0.jpg',
 'tf_astronauts\\tf_astronauts_1.jpg',
 'tf_astronauts\\tf_astronauts_2.jpg',
 'tf_astronauts\\tf_astronauts_3.jpg',
 'tf_astronauts\\tf_astronauts_4.jpg',
 'tf_cats\\tf_cats_0.jpg',
 'tf_cats\\tf_cats_1.jpg',
 'tf_cats\\tf_cats_2.jpg',
 'tf_cats\\tf_cats_3.jpg',
 'tf_cats\\tf_cats_4.jpg']

讓我們僅將轉換應用於宇航員類別。 讓我們使用tf.image轉換。

def preprocess(filepath):
    category = tf.strings.split(filepath, os.sep)[0]
    read_file = tf.io.read_file(filepath)
    decode = tf.image.decode_jpeg(read_file, channels=3)
    resize = tf.image.resize(decode, (200, 200))
    image = tf.expand_dims(resize, 0)
    if tf.equal(category, 'tf_astronauts'):
        image = tf.image.flip_up_down(image)
        image = tf.image.flip_left_right(image)
    # image = tf.image.convert_image_dtype(image, tf.float32)
    # category = tf.cast(tf.equal(category, 'tf_astronauts'), tf.int32)
    return image, category

然后,我們制作tf.data.Dataset

train = tf.data.Dataset.from_tensor_slices(files).\
    shuffle(10).take(4).map(preprocess).batch(4)

當你迭代數據集時,你會看到只有宇航員被翻轉了:

fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images, labels = next(iter(train))
for index, (image, label) in enumerate(zip(images, labels)):
    ax = plt.subplot(2, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(label.numpy().decode())
    ax.imshow(image[0].numpy().astype(int))
plt.show()

在此處輸入圖片說明

請注意,對於訓練,您需要在preprocess取消注釋這兩行,以便它返回一個浮點數和一個整數數組。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM