
Epochs and batches control in Keras

I would like to implement an autoencoder model that acts as follows:

for epoch in range(100):
    for X_batch in batch_list:
        model.train_on_batch(X_batch, X_batch)
        training_error = model.evaluate(X_batch, X_batch, verbose=0)
    # average the training error over the number of batches considered
    # save it as the epoch training error
    # compute the validation error in the same fashion over the validation data
    # compare the two errors and decide whether to keep training or stop

Looking around, fit_generator seemed an option, but I did not understand how to use it. Should I instead use train_on_batch, or fit with just one epoch, to properly fit the model?

What is the best practice for such a case?

From what I can understand, you want to use the validation error as an early stopping criterion. The good news is that Keras already has an early stopping callback. So all you need to do is create the callback and pass it to training; it is invoked at the end of each epoch.

keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None, restore_best_weights=False)
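To make the `patience` and `min_delta` parameters concrete, here is a simplified plain-Python reading of the callback's decision rule in `min` mode (the function name `early_stop_triggered` is made up for illustration; it is not part of the Keras API):

```python
def early_stop_triggered(val_losses, patience=0, min_delta=0.0):
    """Mimic EarlyStopping in 'min' mode: report True once the monitored
    value has failed to improve by more than min_delta for `patience`
    consecutive epochs."""
    best = float("inf")
    wait = 0
    for loss in val_losses:
        if loss < best - min_delta:   # counts as an improvement
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return True           # training would stop here
    return False
```

For example, with `patience=2` a run of validation losses `[5, 4, 4, 4]` triggers the stop, while a monotonically decreasing run never does.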

Let us look at train_on_batch and fit():

train_on_batch(x, y, sample_weight=None, class_weight=None)


fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

You can see that train_on_batch does not take any callbacks as input, so the good choice here is to use fit, unless you want to implement the logic yourself.
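If you did want to implement it yourself, the pseudocode from the question translates into a short loop. This is only a sketch: `run_epoch` and `fit_with_early_stopping` are made-up helper names, and any model exposing Keras-style `train_on_batch(x, y)` and `evaluate(x, y, verbose=0)` methods returning a scalar loss would fit.

```python
import numpy as np

def run_epoch(model, batches):
    """Train on each batch and return the epoch's mean training error."""
    errors = []
    for X_batch in batches:
        model.train_on_batch(X_batch, X_batch)  # autoencoder: target == input
        errors.append(model.evaluate(X_batch, X_batch, verbose=0))
    return float(np.mean(errors))

def fit_with_early_stopping(model, train_batches, val_batches,
                            max_epochs=100, patience=2):
    """Stop once validation error has not improved for `patience` epochs."""
    best_val, wait = float("inf"), 0
    for epoch in range(max_epochs):
        train_error = run_epoch(model, train_batches)
        val_error = float(np.mean([model.evaluate(X, X, verbose=0)
                                   for X in val_batches]))
        if val_error < best_val:
            best_val, wait = val_error, 0
        else:
            wait += 1
            if wait >= patience:
                break
    return best_val
```

This reproduces the per-epoch averaging and the stop-or-continue comparison from the question, but as the answer notes, the built-in callback saves you from maintaining this bookkeeping yourself.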

Now you can call fit as follows:

from keras.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [EarlyStopping(monitor='val_loss', patience=2),
             ModelCheckpoint(filepath='path to latest ckpt', monitor='val_loss',
                             save_best_only=True)]

history = model.fit(train_features, train_target, epochs=num_epochs,
                    callbacks=callbacks, verbose=0, batch_size=your_choice,
                    validation_data=(val_features, val_target))
