
Using Checkpoint saving with train_on_batch in Keras

I'm training my model in batches using train_on_batch, but train_on_batch doesn't seem to have an option to use callbacks, which appear to be required for checkpointing.

I can't use model.fit, as it seems to require loading all of my data into memory.

model.fit_generator is giving me strange problems (like hanging at the end of an epoch).

Here is the example from the Keras API docs showing the use of ModelCheckpoint:

from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.callbacks import ModelCheckpoint

model = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

# Saves the weights whenever validation loss improves.
checkpointer = ModelCheckpoint(filepath='/tmp/weights.hdf5', verbose=1,
                               save_best_only=True)
model.fit(x_train, y_train, batch_size=128, epochs=20,
          verbose=0, validation_data=(x_test, y_test), callbacks=[checkpointer])

If you train on each batch manually, you can do whatever you want after any batch or epoch. There is no need for callbacks; just call model.save or model.save_weights yourself.
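For example, here is a minimal sketch of a manual training loop that reproduces the save_best_only behavior of ModelCheckpoint. The batch source get_batches() and the validation arrays x_val, y_val are hypothetical placeholders, not from the original question; substitute your own data pipeline.

import numpy as np

best_val_loss = np.inf

for epoch in range(20):
    # get_batches() is an assumed generator yielding (x_batch, y_batch) pairs
    for x_batch, y_batch in get_batches():
        loss = model.train_on_batch(x_batch, y_batch)

    # Evaluate on held-out data once per epoch, as ModelCheckpoint would
    val_loss = model.evaluate(x_val, y_val, verbose=0)

    # Save only when validation loss improves (mimics save_best_only=True)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        model.save_weights('/tmp/weights.hdf5')  # or model.save(...) for the full model
        print('Epoch %d: val_loss improved to %.4f, saving weights' % (epoch, val_loss))

You can also save inside the inner loop (e.g. every N batches) if epochs are long; the point is that with train_on_batch the saving schedule is entirely under your control.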
