簡體   English   中英

Keras 模型的 predict 和 predict_on_batch 方法有什么區別?

[英]What is the difference between the predict and predict_on_batch methods of a Keras model?

根據 keras文檔

predict_on_batch(self, x)
Returns predictions for a single batch of samples.

但是,無論是使用一個元素還是多個元素,在批處理上調用時,與標准的predict方法似乎沒有任何區別。

model.predict_on_batch(np.zeros((n, d_in)))

是相同的

model.predict(np.zeros((n, d_in)))

(一個numpy.ndarray形狀的(n, d_out

不同之處在於當您將x數據傳遞給大於一批的數據時。

predict逐批遍歷所有數據,預測標簽。 因此,它在內部分批進行拆分並一次進料一批。

另一方面, predict_on_batch假設您傳入的數據恰好是一個批次,因此將其提供給網絡。 它不會嘗試拆分它(根據您的設置,如果陣列非常大,這可能會對您的 GPU 內存造成問題)

我只想添加一些不適合評論的內容。 似乎predict仔細檢查輸出形狀:

class ExtractShape(keras.engine.topology.Layer):
    def call(self, x):
        return keras.backend.sum(x, axis=0)
    def compute_output_shape(self, input_shape):
        return input_shape

a = keras.layers.Input((None, None))
b = ExtractShape()(a)
m = keras.Model(a, b)
m.compile(optimizer=keras.optimizers.Adam(), loss='binary_crossentropy')
A = np.ones((5,4,3))

然后:

In [163]: m.predict_on_batch(A)
Out[163]: 
array([[5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 5.],
       [5., 5., 5.]], dtype=float32)
In [164]: m.predict_on_batch(A).shape
Out[164]: (4, 3)

但:

In [165]: m.predict(A)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-165-c5ba5fc88b6e> in <module>()

----> 1 m.predict(A)

~/miniconda3/envs/ccia/lib/python3.6/site-packages/keras/engine/training.py in predict(self, x, batch_size, verbose, steps)
   1746         f = self.predict_function
   1747         return self._predict_loop(f, ins, batch_size=batch_size,
-> 1748                                   verbose=verbose, steps=steps)
   1749 
   1750     def train_on_batch(self, x, y,

~/miniconda3/envs/ccia/lib/python3.6/site-packages/keras/engine/training.py in _predict_loop(self, f, ins, batch_size, verbose, steps)
   1306                         outs.append(np.zeros(shape, dtype=batch_out.dtype))
   1307                 for i, batch_out in enumerate(batch_outs):
-> 1308                     outs[i][batch_start:batch_end] = batch_out
   1309                 if verbose == 1:
   1310                     progbar.update(batch_end)

ValueError: could not broadcast input array from shape (4,3) into shape (5,3)

我不確定這是否真的是一個錯誤。

與在單個批次上執行的 predict_on_batch 相比,似乎 predict_on_batch 快得多。

  • 批次和型號信息
    • 批次形狀:(1024, 333)
    • 批處理數據類型:float32
    • 模型參數:~150k
  • 時間結果:
    • 預測: ~1.45 秒
    • predict_on_batch: ~95.5 毫秒

總之, predict 方法有額外的操作來確保正確處理一組批次,而 predict_on_batch 是一個輕量級的替代方法來預測應該在單個批次上使用。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM