How can I get a stateful LSTM to reset its states between epochs during a Keras Tuner search?
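I'm trying to tune a stateful LSTM with Keras Tuner. My code works and the model trains, but I still can't figure out how to make the model reset its states between epochs. Normally I would train one epoch at a time in a loop and call reset_states manually between epochs, but I don't think that is possible with Keras Tuner. Is there an argument I can use to achieve this? Here is my current tuner code: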
```python
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras_tuner import BayesianOptimization

# train_X, train_y, test_X, test_y and model_path are defined elsewhere.
# A stateful model needs a fixed batch size that matches the one passed to search().
batch_size = 1

def build_model(hp):
    model = Sequential()
    # The Input layer fixes the batch shape, so the LSTM layers don't need batch_input_shape.
    model.add(layers.Input(batch_shape=(batch_size, train_X.shape[1], train_X.shape[2])))
    for i in range(hp.Int('num_LSTM_layers', 1, 3)):
        model.add(layers.LSTM(
            units=hp.Int('units_' + str(i), min_value=32, max_value=512, step=4),
            activation=hp.Choice('LSTM_activation_' + str(i),
                                 values=['relu', 'softplus', 'tanh', 'sigmoid',
                                         'softsign', 'selu', 'elu', 'linear'],
                                 default='elu'),
            return_sequences=True, stateful=True))
    for j in range(hp.Int('num_dense_layers', 1, 3)):
        # Use j (not i) and a distinct prefix so the dense hyperparameters
        # don't collide with the LSTM 'units_<i>' entries.
        model.add(layers.Dense(
            units=hp.Int('dense_units_' + str(j), min_value=64, max_value=1024, step=4),
            activation=hp.Choice('dense_activation_' + str(j),
                                 values=['relu', 'softplus', 'tanh', 'sigmoid',
                                         'softsign', 'selu', 'elu', 'linear'],
                                 default='elu')))
        model.add(layers.Dropout(rate=hp.Float('rate_' + str(j), min_value=0.01, max_value=0.50, step=0.01)))
    model.add(layers.Dense(train_y.shape[1], activation='linear'))
    model.compile(
        optimizer=hp.Choice('optimizers', values=['rmsprop', 'adam', 'adadelta', 'Nadam']),
        loss='mse', metrics=['mse'])
    return model

tuner_bo = BayesianOptimization(
    build_model,
    objective='val_loss',
    max_trials=50,
    executions_per_trial=3,
    overwrite=True,
    num_initial_points=10,
    directory=model_path,
    project_name='LSTM_KT_2001',
    allow_new_entries=True,
    tune_new_entries=True)
tuner_bo.search_space_summary()
tuner_bo.search(train_X, train_y, epochs=100, batch_size=batch_size,
                validation_data=(test_X, test_y), verbose=2)
```
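For comparison, the manual workaround mentioned in the question (fit one epoch at a time and reset the states in between) looks roughly like this outside the tuner. This is a minimal sketch assuming tf.keras on TensorFlow 2.x; the toy data and shapes and the helper names `reset_rnn_states` and `train_with_manual_resets` are hypothetical. The helper falls back to `reset_state()` in case only that name exists, as in newer Keras releases.

```python
import numpy as np
from tensorflow import keras


def reset_rnn_states(model):
    """Hypothetical helper: zero the states of every stateful RNN layer."""
    for layer in model.layers:
        if getattr(layer, "stateful", False):
            if hasattr(layer, "reset_states"):
                layer.reset_states()   # tf.keras 2.x method name
            else:
                layer.reset_state()    # name used by newer Keras releases


def train_with_manual_resets(model, x, y, epochs, batch_size):
    """Hypothetical helper: fit one epoch at a time, resetting states in between."""
    losses = []
    for _ in range(epochs):
        # shuffle=False keeps batches in order, which carrying state
        # from one batch to the next assumes.
        history = model.fit(x, y, epochs=1, batch_size=batch_size,
                            shuffle=False, verbose=0)
        losses.append(history.history["loss"][0])
        reset_rnn_states(model)
    return losses


# Toy data: 4 sequences of 5 steps with 3 features; batch size fixed at 2.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5, 3)).astype("float32")
y = rng.normal(size=(4, 1)).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(5, 3), batch_size=2),
    keras.layers.LSTM(4, stateful=True),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
losses = train_with_manual_resets(model, x, y, epochs=3, batch_size=2)
```

The tuner calls `fit()` once per execution with the full `epochs` value, which is why this loop can't be dropped into `search()` directly.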
I overrode the on_epoch_end method in the tuner class, but I'm not sure whether this is the right approach.
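```python
from keras_tuner import BayesianOptimization

class MyBayesianOptimization(BayesianOptimization):
    def on_epoch_end(self, trial, my_hyper_model, epoch, logs=None):
        # Reset the LSTM states, then run the tuner's normal end-of-epoch handling.
        my_hyper_model.reset_states()
        super(MyBayesianOptimization, self).on_epoch_end(trial, my_hyper_model, epoch, logs)
```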
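An alternative that does not depend on overriding tuner internals is an ordinary Keras callback that resets the states at the end of each epoch. Below is a minimal sketch, assuming tf.keras on TensorFlow 2.x; `ResetStatesCallback`, its `num_resets` counter and the toy model are hypothetical and not part of KerasTuner.

```python
import numpy as np
from tensorflow import keras


class ResetStatesCallback(keras.callbacks.Callback):
    """Hypothetical callback: zero stateful RNN states after every epoch."""

    def __init__(self):
        super().__init__()
        self.num_resets = 0  # how many times the states were reset

    def on_epoch_end(self, epoch, logs=None):
        # self.model is attached by Keras when fit() starts.
        for layer in self.model.layers:
            if getattr(layer, "stateful", False):
                if hasattr(layer, "reset_states"):
                    layer.reset_states()   # tf.keras 2.x method name
                else:
                    layer.reset_state()    # name used by newer Keras releases
        self.num_resets += 1


# Toy check outside the tuner (hypothetical data and shapes).
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5, 3)).astype("float32")
y = rng.normal(size=(4, 1)).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(5, 3), batch_size=2),
    keras.layers.LSTM(4, stateful=True),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
reset_cb = ResetStatesCallback()
model.fit(x, y, epochs=3, batch_size=2, shuffle=False, verbose=0,
          callbacks=[reset_cb])
```

With the tuner, the callback would go through `search()`, which forwards extra keyword arguments to `model.fit()`, e.g. `tuner_bo.search(train_X, train_y, epochs=100, batch_size=1, validation_data=(test_X, test_y), callbacks=[ResetStatesCallback()], verbose=2)`. Note that Keras runs the validation pass before `on_epoch_end`, so validation still sees the state left over from that epoch's training batches; if validation should start from a zero state, the same reset loop could also be placed in `on_test_begin`.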