In Python, I'm moving from Keras's Model.fit to a Model.train_on_batch loop for finer control, but the progress bar and the History object returned by fit are useful. Before spending time implementing them from scratch, has anyone found sample code using train_on_batch that reproduces the progress bar and history?
(NB. I had a look at the source code for fit, but there are enough layers of indirection that it's not easy to dig out exactly what it's doing. I also found this, which is helpful but doesn't have the relevant functionality.)
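So, after looking at the Keras source code, I found that tf.keras.callbacks.ProgbarLogger and tf.keras.callbacks.History are what you want. You don't need to create them yourself: tf.keras.callbacks.CallbackList adds both when given add_progbar=True and add_history=True.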
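The mechanism behind both is simple: fit() drives a list of callbacks through on_*_begin/on_*_end hooks, and History just appends each epoch's logs to a dict of lists. Below is a minimal pure-Python sketch of that pattern, for intuition only. MiniCallback, MiniHistory and MiniCallbackList are hypothetical names, not tf.keras classes, and they leave out everything else the real ones handle (params, model references, most hooks).

```python
class MiniCallback:
    """No-op hooks; subclasses override only the ones they need."""

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_train_batch_end(self, batch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


class MiniHistory(MiniCallback):
    """Collects one value per metric per epoch, in the spirit of History.history."""

    def on_train_begin(self, logs=None):
        self.epoch = []
        self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        self.epoch.append(epoch)
        for name, value in (logs or {}).items():
            self.history.setdefault(name, []).append(value)


class MiniCallbackList:
    """Forwards each hook call to every callback, in order."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def __getattr__(self, hook_name):
        # Only reached for names not set on the instance, i.e. the on_* hooks.
        def dispatch(*args, **kwargs):
            for callback in self.callbacks:
                getattr(callback, hook_name)(*args, **kwargs)
        return dispatch


# Usage with made-up loss values standing in for real training results.
history = MiniHistory()
callbacks = MiniCallbackList([history])
callbacks.on_train_begin()
for epoch in range(2):
    callbacks.on_epoch_begin(epoch)
    logs = {"loss": 1.0 / (epoch + 1), "val_loss": 2.0 / (epoch + 1)}
    callbacks.on_epoch_end(epoch, logs)
callbacks.on_train_end(logs)
print(history.history)  # {'loss': [1.0, 0.5], 'val_loss': [2.0, 1.0]}
```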
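Having defined EPOCHS, train_generator (anything that supports len(), indexing into (x, y) batches and on_epoch_end(), e.g. a tf.keras.utils.Sequence) and validation data val_x, val_y, you can replace

history = model.fit(train_generator, validation_data=(val_x, val_y), epochs=EPOCHS)

with the following code: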
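import tensorflow as tf

# Written for tf.keras in TF 2.x; Keras 3's train_on_batch has no reset_metrics argument.
# add_history/add_progbar make CallbackList create the History and ProgbarLogger callbacks.
callbacks = tf.keras.callbacks.CallbackList(
    None,
    add_history=True,
    add_progbar=True,
    model=model,
    epochs=EPOCHS,
    verbose=1,
    steps=len(train_generator),
)

callbacks.on_train_begin()
for epoch in range(EPOCHS):
    # Start each epoch with fresh metric state, as fit() does.
    model.reset_metrics()
    callbacks.on_epoch_begin(epoch)

    for i in range(len(train_generator)):
        callbacks.on_train_batch_begin(i)
        # reset_metrics=False keeps metrics accumulating across the epoch,
        # so the progress bar shows epoch-level values.
        logs = model.train_on_batch(*train_generator[i], reset_metrics=False, return_dict=True)
        callbacks.on_train_batch_end(i, logs)

    # Passing the existing CallbackList stops evaluate() from building its own progress bar.
    validation_logs = model.evaluate(val_x, val_y, callbacks=callbacks, return_dict=True)
    # Prefix validation metrics with 'val_', matching the keys fit() records.
    logs.update({'val_' + name: v for name, v in validation_logs.items()})

    callbacks.on_epoch_end(epoch, logs)  # History records this epoch's logs here
    train_generator.on_epoch_end()       # e.g. reshuffle the Sequence

callbacks.on_train_end(logs)
history = model.history  # the History object, attached to the model by the History callback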
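If you'd rather not lean on Keras internals for the display, a plain progress line is easy to approximate by hand. Here is a minimal pure-Python sketch. format_progress and print_progress are hypothetical helpers, not Keras APIs, and they only imitate the look of ProgbarLogger's output: no timing or ETA, and values are printed exactly as passed in.

```python
import sys


def format_progress(step, total, logs, width=30):
    """Build one progress line, e.g. with width=10: '5/10 [=====.....] - loss: 0.2500'."""
    filled = int(width * step / total)
    bar = "=" * filled + "." * (width - filled)
    line = f"{step}/{total} [{bar}]"
    metrics = " - ".join(f"{name}: {value:.4f}" for name, value in logs.items())
    if metrics:
        line += " - " + metrics
    return line


def print_progress(step, total, logs, width=30):
    """Overwrite the current terminal line; move to a new line on the last step."""
    end = "\n" if step == total else ""
    sys.stdout.write("\r" + format_progress(step, total, logs, width) + end)
    sys.stdout.flush()


# Usage with made-up losses; in the training loop you would pass i + 1,
# len(train_generator) and the logs dict returned by train_on_batch.
steps = 4
for i in range(steps):
    print_progress(i + 1, steps, {"loss": 1.0 / (i + 1)})
```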