
What does train_on_batch() do in keras model?

I saw a code sample (too big to paste here) where the author used model.train_on_batch(in, out) instead of model.fit(in, out). The official Keras documentation says:

Single gradient update over one batch of samples.

But I don't get it. Is it the same as fit(), but instead of doing many feed-forward and backprop steps, it does it once? Or am I wrong?

Yes. train_on_batch performs exactly one gradient update, using the single batch you pass to it.

fit, by contrast, splits the data into many batches and trains on all of them for as many epochs as you ask for (each batch causes one weight update).
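To make the difference concrete, here is a minimal sketch, assuming a recent Keras (2.x via tf.keras, or 3.x); the tiny model and the random data are invented for illustration. It counts optimizer steps through `model.optimizer.iterations`: fit over 100 samples with batch_size=10 for 2 epochs takes 10 × 2 = 20 steps, while one train_on_batch call takes exactly 1.

```python
import numpy as np
import keras


def make_model():
    # Illustrative toy regression model (not from the original question).
    model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")
    return model


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 4)).astype("float32")
y = rng.normal(size=(100, 1)).astype("float32")

# fit(): splits the 100 samples into batches of 10 and repeats for 2 epochs,
# so the optimizer takes 10 * 2 = 20 gradient steps.
fit_model = make_model()
fit_model.fit(x, y, batch_size=10, epochs=2, verbose=0)
fit_steps = int(fit_model.optimizer.iterations.numpy())

# train_on_batch(): treats whatever it receives as ONE batch, takes ONE step,
# and returns the loss computed on that batch.
tob_model = make_model()
loss = tob_model.train_on_batch(x[:10], y[:10])
tob_steps = int(tob_model.optimizer.iterations.numpy())

print(fit_steps, tob_steps)  # 20 1
```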

The idea of using train_on_batch is probably to do more things yourself between batches.

It is used when you want to inspect the model or apply some custom changes after training on each batch.
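A sketch of such a hand-written loop, assuming toy linear data and a hypothetical batch size of 16: you shuffle and slice the data yourself, call train_on_batch once per batch, and get a spot after every single update where your own code can run.

```python
import numpy as np
import keras

model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.05), loss="mse")

rng = np.random.default_rng(1)
x = rng.normal(size=(64, 3)).astype("float32")
true_w = np.array([[1.0], [-2.0], [0.5]], dtype="float32")
y = x @ true_w  # noiseless linear target (illustrative assumption)

batch_size = 16
batch_losses = []
for epoch in range(5):
    order = rng.permutation(len(x))  # reshuffle the sample order each epoch
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        # One gradient update on this batch; returns the batch loss.
        loss = float(model.train_on_batch(x[idx], y[idx]))
        batch_losses.append(loss)
        # Custom per-batch logic would go here: logging, inspecting weights,
        # changing the data, updating another model, etc.

print(len(batch_losses))  # 5 epochs * 4 batches = 20
```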

A more precise use case is GANs. You have to update the discriminator, but while updating the GAN (generator plus discriminator) network you have to keep the discriminator untrainable. So you first train the discriminator, and then train the GAN with the discriminator frozen. See this for more details: https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3
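A condensed sketch of that GAN pattern, assuming toy dense models and random "real" data; the shapes and the `train_step` helper are hypothetical and not taken from the linked article. The discriminator is compiled on its own first, then frozen before the stacked noise-to-discriminator model is compiled. Setting `trainable` explicitly before each train_on_batch call is a defensive choice so that each model's first training call sees the intended frozen or unfrozen state.

```python
import numpy as np
import keras

latent_dim, data_dim, batch = 2, 3, 32
rng = np.random.default_rng(0)

# Toy generator (noise -> fake sample) and discriminator (sample -> P(real)).
generator = keras.Sequential([keras.Input(shape=(latent_dim,)),
                              keras.layers.Dense(data_dim)])
discriminator = keras.Sequential([keras.Input(shape=(data_dim,)),
                                  keras.layers.Dense(1, activation="sigmoid")])
discriminator.compile(optimizer="adam", loss="binary_crossentropy")

# Stacked model: noise -> generator -> discriminator. The discriminator is
# frozen here so that the GAN's optimizer updates only the generator.
discriminator.trainable = False
z = keras.Input(shape=(latent_dim,))
gan = keras.Model(z, discriminator(generator(z)))
gan.compile(optimizer="adam", loss="binary_crossentropy")


def train_step():
    """Hypothetical helper: one discriminator update, then one generator update."""
    real = rng.normal(loc=2.0, size=(batch, data_dim)).astype("float32")
    noise = rng.normal(size=(batch, latent_dim)).astype("float32")
    fake = generator.predict(noise, verbose=0)

    # 1) Train D on real (label 1) vs. generated (label 0) samples.
    discriminator.trainable = True
    x_d = np.concatenate([real, fake])
    y_d = np.concatenate([np.ones((batch, 1)), np.zeros((batch, 1))]).astype("float32")
    d_loss = discriminator.train_on_batch(x_d, y_d)

    # 2) Train G through the frozen D, asking D to label the fakes "real".
    discriminator.trainable = False
    g_loss = gan.train_on_batch(noise, np.ones((batch, 1), dtype="float32"))
    return float(d_loss), float(g_loss)


for step in range(3):
    d_loss, g_loss = train_step()
    print(f"step {step}: d_loss={d_loss:.3f} g_loss={g_loss:.3f}")
```

The generator step uses label 1 for generated samples: the generator is rewarded when the (frozen) discriminator is fooled, and because the discriminator's weights are excluded from that update, it cannot "cheat" by changing the judge.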

The fit method trains the model by passing through all the data you give it, once per epoch. However, because of memory limitations (especially GPU memory), we can't train on a large number of samples at once, so the data has to be divided into small pieces called mini-batches (or just batches). The fit method of Keras models does this splitting for you and passes through all the data you gave it.
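The splitting that fit does can be reproduced by hand. This sketch (toy data, plain SGD, and shuffling turned off so the batch order is deterministic; all of it chosen for illustration) trains two identically initialized models for one epoch, one with fit and one with a loop of train_on_batch calls over the same mini-batches. Under these assumptions both should end up with numerically the same weights.

```python
import numpy as np
import keras


def make_model():
    model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
    model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.01), loss="mse")
    return model


rng = np.random.default_rng(2)
x = rng.normal(size=(50, 4)).astype("float32")
y = rng.normal(size=(50, 1)).astype("float32")
batch_size = 16

a = make_model()
b = make_model()
b.set_weights(a.get_weights())  # start both models from identical weights

# What fit() does for one epoch (shuffle disabled so the batch order is fixed).
a.fit(x, y, batch_size=batch_size, epochs=1, shuffle=False, verbose=0)

# The same epoch by hand: batches of 16, 16, 16 and a final partial batch of 2.
for start in range(0, len(x), batch_size):
    b.train_on_batch(x[start:start + batch_size], y[start:start + batch_size])
```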

However, sometimes we need a more complicated training procedure; for example, we may want to randomly select new samples to put in the batch buffer each epoch (e.g. GAN training and Siamese CNN training). In these cases we don't use the convenient, simple fit method; instead we use the train_on_batch method. To use it, we generate a batch of inputs and a batch of outputs (labels) in each iteration and pass them to this method. It trains the model on all the samples in the batch at once and returns the loss and any other metrics computed on that batch.
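As an illustration of building a new batch on every iteration, the sketch below draws a fresh, class-balanced random batch from an imbalanced toy dataset at each step. The dataset and the `sample_balanced_batch` helper are hypothetical, not Keras APIs. With `return_dict=True` (supported by train_on_batch in recent tf.keras and in Keras 3), the call returns the batch loss and the compiled metrics as a dict.

```python
import numpy as np
import keras

rng = np.random.default_rng(3)
# Imbalanced toy dataset (illustrative): 900 negatives, 100 positives.
x_neg = rng.normal(loc=-1.0, size=(900, 2)).astype("float32")
x_pos = rng.normal(loc=1.0, size=(100, 2)).astype("float32")

model = keras.Sequential([keras.Input(shape=(2,)),
                          keras.layers.Dense(1, activation="sigmoid")])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])


def sample_balanced_batch(half=16):
    """Hypothetical helper: draw a fresh batch with `half` samples per class."""
    neg = x_neg[rng.integers(0, len(x_neg), size=half)]
    pos = x_pos[rng.integers(0, len(x_pos), size=half)]
    xb = np.concatenate([neg, pos])
    yb = np.concatenate([np.zeros((half, 1)), np.ones((half, 1))]).astype("float32")
    return xb, yb


history = []
for step in range(100):
    xb, yb = sample_balanced_batch()  # new random samples every iteration
    logs = model.train_on_batch(xb, yb, return_dict=True)  # loss + metrics for this batch
    history.append(float(logs["loss"]))
```

Sampling the batch yourself is what fit cannot easily do with plain arrays: here every batch is half positives, even though positives are only 10 % of the data.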

