簡體   English   中英

如何以有效的內存方式將圖像數據集讀入numpy數組?

[英]How can I read in image dataset as numpy array in a memory efficient manner?

我試圖將圖像數據集加載為numpy數組。 我怎么能這樣做,以便我不強調本地機器上RAM的限制,或創建一個需要太多內存的數組? 較大的圖像集是訓練集,總共約2GB的圖像。

這是為了訓練需要輸入數據為numpy數組的殘余神經網絡。 我已經使用了模塊glob,PIL,skimage,sklearn和numpy來嘗試加載圖像,但是我這樣做很可能是天真的,因為~2GB的圖像變成了~17(!)GB numpy數組。 我已經嘗試過搜索解決方案,示例等等,但對Python來說這是一個新手,所以這個過程非常慢。

用於天真加載圖像的代碼是

import glob
from skimage.transform import resize
import numpy as np
from sklearn import datasets
from PIL import Image

def root_2_numpy(data_root):
    """
    Load raw images and output a numpy array of all images and numpy array of labels
    Also preprocesses each image to (224,224) using anti-aliasing
    """
    # load images into numpy array
    all_image_paths = list(data_root.glob('*/*'))  # get image paths
    all_image_paths = [str(path) for path in all_image_paths]  # convert to string
    image_ds = np.zeros([len(all_image_paths), 224, 224,3])  # initialize image dataset
    for i in range(len(all_image_paths)):
        print(i)
        im = Image.open(all_image_paths[i])  # read image as RGB using matplotlib
        if im.mode == 'RGBA' or im.mode == 'L' or im.mode == 'CMYK':
            im = im.convert('RGB')
        elif im.mode =='P':
            im = im.convert('RGBA')
            im = im.convert('RGB')
        im = np.array(im)
        im = resize(im, (224,224), anti_aliasing=True)  # resize image using skimage
        image_ds[i,:,:,:] = im

    # load labels into numpy array
    label_ds = datasets.load_files(data_root, load_content=False, shuffle=False)  # get labels
    n_classes = len(label_ds.target_names)
    Y_ds = np.eye(len(label_ds.target_names))[label_ds.target.reshape(-1)]

    return image_ds, Y_ds, n_classes

我希望這會返回一個約2GB的numpy數組,它具有圖像數量,圖像寬度,圖像高度和圖像的3個通道的尺寸(N,W,H,C)。 這不是問題所在,但我也希望有標簽的數據,它們是根目錄中的類別名稱。

除了幫助我有效地加載數據之外,我非常希望能夠深入了解我的代碼如何創建如此大的numpy數組。 在我寫這篇文章時,我感覺在轉換非RBG圖像的圖像類型時可能會產生比預期更多的圖像。

numpy.zeros創建的數組的默認數據類型是64位浮點。 所以image_ds = np.zeros([len(all_image_paths), 224, 224,3])會創建一個比你需要的大8倍的數組。 添加dtype參數,以便image_ds具有數據類型uint8 (8位無符號整數):

image_ds = np.zeros([len(all_image_paths), 224, 224,3], dtype=np.uint8)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM