
Keras: is there sample code for train_on_batch which has history + progress?

In Python, I'm moving from Keras's Model.fit to a Model.train_on_batch loop for finer control. But the progress bar and History object returned by fit are useful. Before wasting time implementing them from scratch, I was wondering if anyone had found sample code using train_on_batch that reproduced the progress bar and history?

(N.B. I had a look at the source code for fit, but there are enough layers of indirection that it's not easy to dig out exactly what it's doing. I also found this, which is helpful but doesn't have the relevant functionality.)

After looking at the Keras source code, I found that tf.keras.callbacks.ProgbarLogger and tf.keras.callbacks.History are what you want. Both can be attached for you by tf.keras.callbacks.CallbackList.

Source code

keras/callbacks.py#L259

keras/callbacks.py#L263
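To make it clearer what the History callback contributes, here is a minimal pure-Python sketch of its bookkeeping: one list per metric name, appended to at the end of every epoch. MiniHistory is a hypothetical stand-in written for illustration, not the real Keras class, and the log values are made up.

```python
class MiniHistory:
    """Simplified stand-in for tf.keras.callbacks.History (illustration only)."""

    def __init__(self):
        self.epoch = []    # epoch indices seen so far
        self.history = {}  # metric name -> list of per-epoch values

    def on_epoch_end(self, epoch, logs=None):
        # Record the epoch index, then append each metric to its own list.
        self.epoch.append(epoch)
        for name, value in (logs or {}).items():
            self.history.setdefault(name, []).append(value)


# Hypothetical per-epoch logs, shaped like the dicts a training loop would pass in.
recorder = MiniHistory()
recorder.on_epoch_end(0, {"loss": 0.9, "val_loss": 1.0})
recorder.on_epoch_end(1, {"loss": 0.6, "val_loss": 0.8})
print(recorder.history)
```

This is why whatever dict you hand to on_epoch_end matters: only the keys present in it end up in the history.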

Having defined EPOCHS, train_generator and the validation data val_x, val_y, you can replace

history = model.fit(train_generator, validation_data = (val_x, val_y), epochs = EPOCHS)

with the following code:

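import tensorflow as tf

# Written against TF 2.x tf.keras. CallbackList creates the same History and
# ProgbarLogger callbacks that fit() sets up when add_history / add_progbar are True.
callbacks = tf.keras.callbacks.CallbackList(
    None,
    add_history = True,
    add_progbar = True,
    model = model,
    epochs = EPOCHS,
    verbose = 1,
    steps = len(train_generator)
)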

callbacks.on_train_begin()
for epoch in range(EPOCHS):
    model.reset_metrics()  # start each epoch with fresh metric state
    callbacks.on_epoch_begin(epoch)
    for i in range(len(train_generator)):
        callbacks.on_train_batch_begin(i)
        # reset_metrics=False lets the metrics accumulate over the whole epoch
        logs = model.train_on_batch(*train_generator[i], reset_metrics = False, return_dict = True)
        callbacks.on_train_batch_end(i, logs)

    # Evaluate on the validation data and merge the results under 'val_' keys
    validation_logs = model.evaluate(val_x, val_y, callbacks = callbacks, return_dict = True)
    logs.update({'val_' + name: v for name, v in validation_logs.items()})

    callbacks.on_epoch_end(epoch, logs)
    train_generator.on_epoch_end()

# Pass the logs of the final epoch
callbacks.on_train_end(logs)
history = model.history
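Once the loop above has finished, history.history should have the same dict-of-lists shape that fit returns, so it can be inspected in the usual way. As a small sketch, the helper below (best_epoch is a hypothetical name, and the numbers are made up to stand in for history.history) picks the epoch with the lowest validation loss, assuming a 'val_loss' key is present.

```python
def best_epoch(history_dict, metric="val_loss"):
    """Return (epoch_index, value) for the lowest recorded value of `metric`."""
    values = history_dict[metric]
    index = min(range(len(values)), key=values.__getitem__)
    return index, values[index]


# Made-up numbers standing in for history.history after three epochs.
example = {"loss": [0.9, 0.6, 0.5], "val_loss": [1.0, 0.7, 0.8]}
epoch, value = best_epoch(example)
print(f"best epoch: {epoch}, val_loss: {value}")
```

With a real run you would call best_epoch(history.history) instead of passing the example dict.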
