簡體   English   中英

在 keras 中針對不同批次大小訓練 model

[英]training model for different batch sizes in keras

我想針對不同的批量大小訓練我的 model,即:[64, 128] 我正在使用如下的 for 循環

   epoch=2 
   batch_sizes = [128,256] 
   for i in range(len(batch_sizes)):
     history = model.fit(x_train, y_train, batch_sizes[i], epochs=epochs, 
          callbacks=[early_stopping, chk], validation_data=(x_test, y_test))

對於上面的代碼,我的 model 產生以下結果:

    Epoch 1/2
    311/311 [==============================] - 157s 494ms/step - loss: 0.2318 - 
    f1: 0.0723 
    Epoch 2/2
    311/311 [==============================] - 152s 488ms/step - loss: 0.1402 - 
    f1: 0.4360 

    Epoch 1/2
    156/156 [==============================] - 137s 877ms/step - loss: 0.1197 - 
    f1: **0.5450** 
    Epoch 2/2
    156/156 [==============================] - 136s 871ms/step - loss: 0.1132 - 
    f1: 0.5756

看起來 model 在完成批量 64 的訓練后繼續訓練,即我想讓我的 model 從頭開始訓練下一批,我該怎么做,請指導我。 ps:我嘗試過的:

   epoch=2 
   batch_sizes = [128,256] 
   for i in range(len(batch_sizes)):
     history = model.fit(x_train, y_train, batch_sizes[i], epochs=epochs, 
          callbacks=[early_stopping, chk], validation_data=(x_test, y_test))
   keras.backend.clear_session()

它也沒有奏效

您可以編寫 function 來定義 model,並且您需要在隨后的fit調用之前調用它。 如果您的 model 包含在model中,則權重會在訓練期間更新,並且在 fit 調用后保持不變。 這就是為什么您需要重新定義 model。 這可以幫助你

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy as np

X = np.random.rand(1000,5)
Y = np.random.rand(1000,1)

def build_model():
    model = Sequential()
    model.add(Dense(64,input_shape=(X.shape[1],)))
    model.add(Dense(Y.shape[1]))
    model.compile(loss='mse',optimizer='Adam')
    return model

epoch=2
batch_sizes = [128,256]
for i in range(len(batch_sizes)):
    model = build_model()
    history = model.fit(X, Y, batch_sizes[i], epochs=epoch, verbose=2)
    model.save('Model_' + str(batch_sizes[i]) + '.h5')

然后,output 看起來像:

Epoch 1/2
8/8 - 0s - loss: 0.3164
Epoch 2/2
8/8 - 0s - loss: 0.1367
Epoch 1/2
4/4 - 0s - loss: 0.7221
Epoch 2/2
4/4 - 0s - loss: 0.4787

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM