簡體   English   中英

我可以使用Keras中的ImageDataGenerator()和flow_from_directory()生成uint8標簽嗎?

[英]Can I generate uint8 label using ImageDataGenerator() and flow_from_directory() in Keras?

我正在處理2D語義分段任務。

在Keras API文檔中,這些僅包含示例,顯示如何為圖像分類安排數據集,而不是語義分段。

所以我像這樣安排我的圖像和標簽

SEED = 111
batch_size = 2
image_datagen = ImageDataGenerator(
    horizontal_flip=True,
    zca_epsilon=9,
    # fill_mode='nearest',
)
image_generator = image_datagen.flow_from_directory(
    directory="/xxx/images",
    class_mode=None,
    batch_size=batch_size,
    seed=SEED,
)


def preprocessing_function(image):
    return image.astype(np.uint8)


label_datagen = ImageDataGenerator(
    horizontal_flip=True,
    zca_epsilon=9,
    rescale=1,
    preprocessing_function=preprocessing_function,
    # fill_mode='nearest',
)
label_generator = image_datagen.flow_from_directory(
    directory="/xxx/labels",
    class_mode=None,
    batch_size=batch_size,
    seed=SEED,
)

train_generator = zip(image_generator, label_generator)
print(len(image_generator))
i = 0
for image_batch, label_batch in iter(train_generator):
    print(image_batch.shape, label_batch.shape) # (2, 256, 256, 3) (2, 256, 256, 3)
    print(image_batch.dtype, label_batch.dtype) # float32 float32
    i += 1
    if i == 5:
        break

但是我發現生成的標簽圖像的類型是float32 ,所以我將一個preprocessing_function函數添加到label_datagen只是為了將dtype轉換為uint8, 但生成的標簽圖像'dtype仍然是float32 ,似乎preprocessing_function什么也沒做。

我該如何修復這個問題?

如何將我的標簽數據更改為uint8?

添加預處理函數來投射標簽圖像的dtype是“常見做法”嗎?

謝謝你的建議!

我遇到了同樣的問題,並將發電機包裝成另一個。 它有效,但它是一種kludge

label_generator = (x.astype(np.uint8) for x in label_generator)
train_generator = zip(image_generator, label_generator)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM