
Epochs and batches control in Keras

I would like to train an autoencoder model as follows:

for epoch in range(100):
    batch_errors = []
    for X_batch in batch_list:
        model.train_on_batch(X_batch, X_batch)
        batch_errors.append(model.evaluate(X_batch, X_batch, verbose=0))
    # average the training error over the number of batches considered
    # and save it as the epoch training error
    epoch_train_error = sum(batch_errors) / len(batch_errors)
    # compute the validation error in the same fashion over the validation data,
    # then compare the two errors and decide whether to keep training or stop

From looking around, fit_generator seemed like an option, but I did not understand how to use it. Should I instead use train_on_batch, or fit with just one epoch, to properly fit the model?

What is the best practice in such a case?

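From what I understand, you want to use the validation error as an early-stopping criterion. The good news is that Keras already has an EarlyStopping callback, so all you need to do is create it and pass it to training. Keras checks the monitored value (here val_loss) at the end of every epoch and stops once it has not improved by more than min_delta for patience consecutive epochs.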

keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None, restore_best_weights=False)

Let us compare the signatures of train_on_batch() and fit():

train_on_batch(x, y, sample_weight=None, class_weight=None)


fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

As you can see, train_on_batch does not accept callbacks, so fit is the better choice here, unless you want to implement the early-stopping logic yourself.

You can then call fit as follows:

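from keras.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [EarlyStopping(monitor='val_loss', patience=2),
             ModelCheckpoint(filepath='path to best ckpt', monitor='val_loss', save_best_only=True)]

# for an autoencoder, train_target is train_features and val_target is val_features
history = model.fit(train_features, train_target, epochs=num_epochs, callbacks=callbacks, verbose=0,
                    batch_size=your_choice, validation_data=(val_features, val_target))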
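For a self-contained check of the same idea, here is a minimal sketch that trains a tiny dense autoencoder on random data. It assumes TensorFlow 2.x with its bundled Keras; the data shapes, layer sizes, and the checkpoint filename `best_autoencoder.keras` are illustrative assumptions, not part of the original answer.

```python
import numpy as np
from tensorflow import keras

rng = np.random.default_rng(0)
# Illustrative random data: 1000 training and 200 validation samples, 20 features.
x_train = rng.random((1000, 20)).astype("float32")
x_val = rng.random((200, 20)).astype("float32")

# A tiny dense autoencoder: 20 -> 8 -> 20 (sizes chosen only for the example).
model = keras.Sequential([
    keras.Input(shape=(20,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(20, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="mse")

callbacks = [
    # Stop once val_loss has not improved for 2 consecutive epochs and
    # restore the weights from the best epoch.
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=2,
                                  restore_best_weights=True),
    # Save the model each time val_loss reaches a new minimum.
    keras.callbacks.ModelCheckpoint("best_autoencoder.keras", monitor="val_loss",
                                    save_best_only=True),
]

# For an autoencoder the input is also the target.
history = model.fit(x_train, x_train, epochs=100, batch_size=32,
                    validation_data=(x_val, x_val), callbacks=callbacks, verbose=0)
print("epochs run:", len(history.history["val_loss"]))
```

Here epochs=100 is only an upper bound; EarlyStopping usually ends training earlier, and history.history holds the per-epoch loss and val_loss that the question wanted to compare by hand.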
