
What is causing large jumps in training accuracy and loss between epochs?

While training a neural network in TensorFlow 2.0 (Python), I'm noticing that training accuracy and loss change dramatically between epochs. I'm aware that the metrics printed are an average over the entire epoch, but accuracy seems to drop significantly after each epoch, despite the average always increasing.

The loss also exhibits this behavior, dropping significantly each epoch while the average increases. Here is an image of what I mean (from TensorBoard):

(Image: strange training behavior, TensorBoard screenshot)

I've noticed this behavior on all of the models I've implemented myself, so it could be a bug, but I want a second opinion on whether this is normal behavior and if so what does it mean?

Also, I'm using a fairly large dataset (roughly 3 million examples). Batch size is 32, and each dot in the accuracy/loss graphs represents 50 batches (so 2k on the graph = 100k batches; with ~3 million examples and a batch size of 32, one epoch is roughly 94k batches). The learning-rate graph is plotted per batch (1:1).
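For reference, per-step graphs like the ones above are usually produced by logging metrics per batch rather than per epoch. Below is a minimal sketch of one way to do that with a custom Keras callback; the callback name and the 50-batch interval are just illustrative (not the asker's actual setup), and depending on the TF/Keras version the values in logs may be raw per-batch numbers or the running epoch averages that model.fit prints.

    import tensorflow as tf

    class BatchMetricsLogger(tf.keras.callbacks.Callback):
        """Write training metrics to TensorBoard every N batches (illustrative only)."""

        def __init__(self, log_dir, every_n_batches=50):
            super().__init__()
            self.writer = tf.summary.create_file_writer(log_dir)
            self.every_n_batches = every_n_batches
            self.seen = 0

        def on_train_batch_end(self, batch, logs=None):
            self.seen += 1
            if self.seen % self.every_n_batches == 0:
                with self.writer.as_default():
                    for name, value in (logs or {}).items():
                        # Note: these values may already be epoch running averages,
                        # depending on the TF/Keras version.
                        tf.summary.scalar("batch_" + name, value, step=self.seen)

Such a callback would then be passed to model.fit(..., callbacks=[BatchMetricsLogger("logs/batch")]).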

I recently ran into this kind of issue while working on an object-localization project. In my case, there were three main candidates.

  • I was not shuffling my training data, which can produce a jump in loss after each epoch (see the shuffling sketch after this list).

  • I had defined a new loss function based on IOU. It was something like this:

      def new_loss(y_true, y_pred):
          mse = tf.losses.mean_squared_error(y_true, y_pred)
          iou = calculate_iou(y_true, y_pred)
          return mse + (1 - iou)

    I also suspected this loss might be contributing to the jump after each epoch, but I was not able to replace it to test that. (A rough sketch of the IOU part is given after this list.)

  • I was using the Adam optimizer, so a possible thing to try was changing it to see how the training is affected.
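A minimal sketch of the shuffling fix from the first bullet, assuming a tf.data input pipeline (the tensors and sizes below are placeholders, not the answerer's actual data):

    import tensorflow as tf

    # Placeholder tensors standing in for the real images and target boxes.
    images = tf.zeros([1000, 64, 64, 3])
    boxes = tf.zeros([1000, 4])

    dataset = (tf.data.Dataset.from_tensor_slices((images, boxes))
               .shuffle(buffer_size=1000, reshuffle_each_iteration=True)  # reshuffle every epoch
               .batch(32)
               .prefetch(tf.data.experimental.AUTOTUNE))

    # model.fit(dataset, epochs=10)  # then train as usual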
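The calculate_iou helper from the second bullet isn't shown in the answer; here is a rough, hypothetical sketch of a box IOU for [x_min, y_min, x_max, y_max] coordinates, just to make the loss above self-contained:

    import tensorflow as tf

    def calculate_iou(y_true, y_pred):
        # Hypothetical helper: boxes are assumed to be [x_min, y_min, x_max, y_max].
        x1 = tf.maximum(y_true[..., 0], y_pred[..., 0])
        y1 = tf.maximum(y_true[..., 1], y_pred[..., 1])
        x2 = tf.minimum(y_true[..., 2], y_pred[..., 2])
        y2 = tf.minimum(y_true[..., 3], y_pred[..., 3])
        intersection = tf.maximum(x2 - x1, 0.0) * tf.maximum(y2 - y1, 0.0)
        area_true = (y_true[..., 2] - y_true[..., 0]) * (y_true[..., 3] - y_true[..., 1])
        area_pred = (y_pred[..., 2] - y_pred[..., 0]) * (y_pred[..., 3] - y_pred[..., 1])
        union = area_true + area_pred - intersection
        return intersection / tf.maximum(union, 1e-7)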

Conclusion

I changed Adam to SGD and shuffled my training data. There was still a jump in the loss, but it was minimal compared to before the changes: my loss spike was ~0.3 before and ~0.02 afterwards.
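The optimizer change amounts to a one-line difference at compile time. A sketch, where the learning rate and momentum are illustrative values rather than the answerer's, and model / new_loss are the already-built model and the custom loss from above:

    import tensorflow as tf

    # Before: optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3)
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3, momentum=0.9),
        loss=new_loss,
        metrics=["accuracy"],
    )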

Note

I should add that there are many discussions about this topic; I tried the suggested solutions that seemed like likely candidates for my model.

It seems this phenomenon comes from the fact that the model has high batch-to-batch variance in accuracy and loss. This is illustrated by plotting the actual per-step metrics instead of the average over the epoch:

(Image: raw per-step accuracy/loss for a single epoch)

Here you can see that the per-batch metrics vary widely. (This graph covers just one epoch, but the point stands.)

Since the average metrics were reported per epoch, the running average resets at the start of each new epoch, and at that point it is highly likely to be lower than the previous epoch's final average, producing a dramatic drop in the reported value, illustrated in red below:

(Image: the same metrics with the per-epoch running average overlaid in red)

If you imagine the discontinuities in the red graph as epoch transitions, you can see why you would observe the phenomenon in the question.
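A tiny simulation with made-up numbers (not the real training metrics) makes the mechanism concrete: within an epoch the running average smooths out the batch noise, but at the first batch of the next epoch it resets to a single noisy value, which can land far from the previous epoch's final average.

    import random

    random.seed(0)
    batches_per_epoch = 1000

    for epoch in range(3):
        total = 0.0
        for b in range(batches_per_epoch):
            # Noisy per-batch accuracy with a slow upward trend (made-up numbers).
            step = epoch * batches_per_epoch + b
            acc = 0.5 + 0.0001 * step + random.uniform(-0.2, 0.2)
            total += acc
            running_avg = total / (b + 1)  # the value a per-epoch running average reports
            if b == 0 or b == batches_per_epoch - 1:
                print(f"epoch {epoch} batch {b:4d}: running average = {running_avg:.3f}")

The value printed at batch 0 of each new epoch is a single noisy sample, so it typically jumps away from the smooth value printed at the end of the previous epoch, which is exactly the discontinuity visible at the epoch boundaries.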

TL;DR The model has a very high variance in its output with respect to each batch.
