簡體   English   中英

model.fit_generator 中的 steps_per_epoch 實際上在做什么?

[英]What is steps_per_epoch in model.fit_generator actually doing?

在閱讀了model.fit_generator方法中有關steps_per_epoch所需參數的 Keras 文檔后,我對它的理解是:

如果數據集包含“N”個樣本,並且生成器 function(傳遞給 Keras)為每個調用返回“B = batch_size”樣本數(在這里,我認為調用是生成器函數的單個產量)並且因為 steps_per_epoch = ceil(N/B) 生成器被調用steps_per_epoch次,以便完整數據集在一個 epoch 后通過 model,並且每個 epoch 都重復相同的過程,直到訓練完成。

為了測試我的理解是否正確,我實現了以下

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

index = 0

def get_values(inputs, targets):
    i = 0
    while True:
        yield inputs[i], targets[i]
        i += 1
        if i >= len(inputs):
            i = 0


def get_batch(inputs, targets, batch_size=2):
    global index
    batch_X = []
    batch_Y = []
    for inp, targ in get_values(inputs, targets):
        batch_X.append(inp)
        batch_Y.append(targ)

        if len(batch_X) >= batch_size:
            yield np.array(batch_X), np.array(batch_Y)
            index += 1
            batch_X = []
            batch_Y = []


data = list(range(10))
labels = [2*val for val in range(10)]

model = Sequential([
    Dense(16, activation='relu', input_shape=(1, )),
    Dense(1)
])

model.compile(optimizer='rmsprop', loss='mean_squared_error')
model.fit_generator(get_batch(data, labels, batch_size=2), steps_per_epoch=5, epochs=1, verbose=False)

print(index) # Should Print 5 but it prints 15

該程序並不難理解...

但根據我的解釋,它應該打印 5 但它打印 15。我對steps_per_epoch的解釋錯了嗎? 如果是這樣,請給我steps_per_epoch的正確解釋

PS。 我是 Tensorflow 和 Keras 的新手,在此先感謝。

沒有通過您的代碼 go 但您的原始解釋是正確的。 實際上,根據位於此處的文檔,您可以省略每個時期的步驟,並且 model.fit 會將數據集的長度 (N) 除以批處理大小來確定步驟。 我確實復制並運行了您的代碼。 猜猜它把索引打印為 5。我能想到的唯一可能不同的是進口。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM