Difference between TensorFlow model fit and train_on_batch

I am building a vanilla DQN model to play the OpenAI gym Cartpole game.

However, in the training step, where I feed in the states as input and the target Q-values as labels, model.fit(x=states, y=target_q) works fine and the agent eventually plays the game well. With model.train_on_batch(x=states, y=target_q), the loss does not decrease and the model plays no better than a random policy.

What is the difference between fit and train_on_batch? To my understanding, fit calls train_on_batch under the hood with a default batch size of 32, so the two should behave the same when the batch size equals the size of the data I feed in.

The full code is here if more contextual information is needed to answer this question: https://github.com/ultronify/cartpole-tf

model.fit trains for one or more epochs, which usually means it trains on multiple batches. model.train_on_batch, as the name implies, trains on only one batch.

To give a concrete example, imagine you are training a model on 10 images with a batch size of 2. model.fit will train on all 10 images, so it will update the weights 5 times per epoch. (You can specify multiple epochs, in which case it iterates over your dataset that many times.) model.train_on_batch will perform a single weight update, because you give the model only one batch. With a batch size of 2, you would pass model.train_on_batch two images.
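The arithmetic in that example can be checked with a few lines of Python. This is only a sketch: updates_per_fit is a hypothetical helper, not a Keras function, and it assumes that a final partial batch still produces one update.

```python
import math


def updates_per_fit(num_samples, batch_size, epochs=1):
    """Number of weight updates a fit-style loop would perform.

    Hypothetical helper for illustration only. Assumes the last,
    possibly smaller, batch is still used for one update.
    """
    batches_per_epoch = math.ceil(num_samples / batch_size)
    return batches_per_epoch * epochs


print(updates_per_fit(10, 2))            # 10 images, batch size 2 -> 5
print(updates_per_fit(10, 2, epochs=3))  # three passes over the data -> 15
print(updates_per_fit(10, 32))           # batch size >= data size -> 1
```

The last call covers the case raised in the question: when the batch size is at least as large as the data and there is one epoch, a fit-style loop performs a single update, just as one train_on_batch call does.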

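If we assume that model.fit calls model.train_on_batch under the hood (though I don't think it does), then model.train_on_batch would be called multiple times, most likely in a loop. Here is a simplified sketch to explain.

def fit(model, x, y, batch_size=32, epochs=1):
    # Simplified sketch, not the actual Keras implementation.
    for epoch in range(epochs):
        # Walk over the data in consecutive slices of batch_size samples.
        for start in range(0, len(x), batch_size):
            batch_x = x[start:start + batch_size]
            batch_y = y[start:start + batch_size]
            model.train_on_batch(batch_x, batch_y)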
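To see the difference in the number of updates without installing TensorFlow, here is a minimal sketch using a stand-in model. TinyLinearModel and fit_sketch are hypothetical names made up for this illustration; the class only mimics the idea of train_on_batch (one gradient step per call) and is not the Keras API.

```python
class TinyLinearModel:
    """Hypothetical stand-in for a Keras model: fits y = w * x by gradient descent."""

    def __init__(self, learning_rate=0.01):
        self.w = 0.0
        self.learning_rate = learning_rate
        self.num_updates = 0

    def train_on_batch(self, batch_x, batch_y):
        """Run exactly one gradient step and return the batch MSE before the step."""
        n = len(batch_x)
        errors = [self.w * x - y for x, y in zip(batch_x, batch_y)]
        loss = sum(e * e for e in errors) / n
        grad = 2.0 * sum(e * x for e, x in zip(errors, batch_x)) / n
        self.w -= self.learning_rate * grad
        self.num_updates += 1
        return loss


def fit_sketch(model, x, y, batch_size=32, epochs=1):
    """Loop over the data in mini-batches, one train_on_batch call per batch."""
    for _ in range(epochs):
        for start in range(0, len(x), batch_size):
            model.train_on_batch(x[start:start + batch_size],
                                 y[start:start + batch_size])


x = [float(i) for i in range(1, 11)]  # 10 samples
y = [3.0 * value for value in x]      # the true weight is 3.0

single = TinyLinearModel()
single.train_on_batch(x[:2], y[:2])   # one batch of 2 samples: one update

looped = TinyLinearModel()
fit_sketch(looped, x, y, batch_size=2, epochs=1)  # 5 batches: 5 updates

print(single.num_updates, looped.num_updates)  # prints "1 5"
```

After a single call the stand-in's weight has barely moved from its starting value, while the looped version has taken five steps toward the true weight. Whether this explains the behaviour in the question depends on how often train_on_batch is called in the training loop, which this sketch does not model.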
