簡體   English   中英

自定義Keras生成器比Keras的bult生成器慢得多

[英]Custom Keras generator much slower compared to Keras' bult in generator

我有一個多標簽分類問題。 我寫了這個自定義生成器。 它從磁盤讀取圖像和輸出標簽,並以32的大小批量返回它們。

def get_input(img_name):
    path = os.path.join("images", img_name)
    img = image.load_img(path, target_size=(224, 224))

    return img


def get_output(img_name, file_path):
    data = pd.read_csv(file_path, delim_whitespace=True, header=None)

    img_id = img_name.split(".")[0]
    img_id = img_id.lstrip("0")
    img_id = int(img_id)

    labels = data.loc[img_id - 1].values
    labels = labels[1:]

    labels = list(labels)
    label_arrays = []
    for i in range(20):
        val = np.zeros((1))
        val[0] = labels[i]
        label_arrays.append(val)

    return label_arrays


def preprocess_input(img_name):
    img = get_input(img_name)
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)    
    return x

def train_generator(batch_size):
    file_path = "train.txt"
    data = pd.read_csv(file_path, delim_whitespace=True, header=None)

    while True:
        for i in range(math.floor(8000/batch_size)):
            x_batch = np.zeros(shape=(32, 224, 224, 3))
            y_batch = np.zeros(shape=(32, 20))
            for j in range(batch_size):
                img_name = data.loc[i * batch_size + j].values
                img_name = img_name[0]
                x = preprocess_input(img_name)
                y = get_output(img_name, file_path)
                x_batch[j, :, :, :] = x
                y_batch[j] = y

            ys = []
            for i in range(20):
              ys.append(y_batch[:,i])

            yield(x_batch, ys)

標簽返回模型有一個小問題,並在以下問題中得到解決: 訓練多輸出keras模型

我在單個輸出問題上測試了此生成器。 此自定義生成器非常慢。 使用此自定義生成器的單個時間段的預計到達時間約為27小時,而內置生成器(使用flow_from_directory)單個時間段則需要25分鍾。 我究竟做錯了什么?

除使用的發電機外,兩個測試的訓練過程相同。 驗證生成器類似於訓練生成器。 我知道我無法達到Keras內置發電機的效率,但是這種速度差異太大。

編輯

我閱讀了一些有關創建自定義生成器的指南。

編寫定制的Keras生成器

fit_generator()的自定義生成器,生成具有不同形狀的多個輸入

也許內置的生成器會在您的gpu上處理數據,而您的自定義生成器則在cpu上運行,這會大大降低速度。

另一個猜測是因為Keras在后台使用數據集 您的實現可能使用feed-dict ,這是將信息傳遞給TensorFlow的最慢方法。 將數據輸入模型的最好方法是使用輸入管道,以確保GPU永遠不必等待新的東西進入。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM