How do I get a loss per epoch and not per batch?

In my understanding, an epoch is one complete pass over the whole dataset, which can be repeated any number of times, and the dataset is processed in parts called batches. After each train_on_batch call a loss is calculated, the weights are updated, and the next batch should get better results. These losses indicate the quality and learning state of my two NNs.

In several sources, the loss is calculated (and printed) per epoch, so I am not sure whether I am doing this right.

At the moment my GAN looks like this:

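import numpy as np

# n_epochs and batches are defined elsewhere; "..." stands for the omitted arguments
for epoch in range(n_epochs):
    for batch in batches:
        fakes = generator.predict_on_batch(batch)

        dlc = discriminator.train_on_batch(batch, ...)
        dlf = discriminator.train_on_batch(fakes, ...)
        dis_loss_total = 0.5 * np.add(dlc, dlf)

        g_loss = gan.train_on_batch(batch, ...)

        # save losses to array to work with later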

These losses are per batch. How do I get them per epoch? As an aside: do I need per-epoch losses at all, and if so, what for?

There is no direct way to compute the loss for an epoch. In practice, the loss of an epoch is usually defined as the average of the batch losses in that epoch, so you can accumulate the loss values during the epoch and, at the end, divide the sum by the number of batches:

epoch_loss = []
for epoch in range(n_epochs):
    acc_loss = 0.
    for batch in range(n_batches):
        # do the training 
        loss = model.train_on_batch(...)
        acc_loss += loss
    epoch_loss.append(acc_loss / n_batches)
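The loop above assumes train_on_batch returns a single number. When a Keras model is compiled with metrics, train_on_batch returns a list of scalars (loss plus metrics) instead, so averaging with np.mean(..., axis=0) covers both cases. A minimal sketch applied to the GAN loop from the question; the per-batch values are made-up stand-ins for the real train_on_batch results, and epoch_mean and history are hypothetical names:

```python
import numpy as np


def epoch_mean(batch_results):
    """Average per-batch results into one per-epoch value.

    Each entry may be a scalar loss or a list like [loss, accuracy]
    (the shape Keras's train_on_batch returns when metrics are compiled in).
    """
    return np.mean(np.asarray(batch_results, dtype=float), axis=0)


n_epochs, n_batches = 2, 4
history = {"d_loss": [], "g_loss": []}  # one entry per epoch

for epoch in range(n_epochs):
    d_batch, g_batch = [], []  # per-batch values for this epoch only
    for step in range(n_batches):
        # Made-up stand-ins for the real train_on_batch return values.
        dlc = [0.7 - 0.1 * step - 0.1 * epoch, 0.5]  # [loss, acc] on real data
        dlf = [0.9 - 0.1 * step - 0.1 * epoch, 0.4]  # [loss, acc] on fakes
        g_loss = 1.2 - 0.05 * step - 0.1 * epoch     # scalar generator loss

        d_batch.append(0.5 * np.add(dlc, dlf))  # same combination as the question
        g_batch.append(g_loss)

    # Collapse this epoch's batch values into one entry per epoch.
    history["d_loss"].append(epoch_mean(d_batch))
    history["g_loss"].append(epoch_mean(g_batch))
    print(f"epoch {epoch + 1}: d_loss={history['d_loss'][-1][0]:.3f} "
          f"g_loss={history['g_loss'][-1]:.3f}")
```

Keeping the raw per-batch lists as well costs little and lets you recompute other summaries later.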

As for the other question, one use of the epoch loss is as a criterion for stopping the training (although the validation loss, not the training loss, is usually used for that).
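Keras ships this idea as the EarlyStopping callback for fit(), but a hand-written train_on_batch loop needs its own check. A minimal sketch, assuming you keep one loss per epoch in a list (ideally validation losses, as noted above); should_stop is a hypothetical helper whose patience and min_delta arguments only loosely mirror that callback's parameters:

```python
def should_stop(epoch_losses, patience=3, min_delta=0.0):
    """Return True when none of the last `patience` epochs improved on the
    best earlier loss by more than `min_delta` (lower loss is better)."""
    if len(epoch_losses) <= patience:
        return False  # not enough history yet
    best_before = min(epoch_losses[:-patience])
    best_recent = min(epoch_losses[-patience:])
    return best_recent >= best_before - min_delta


# Usage: the loss has stalled around 0.8 for three epochs.
epoch_loss = [1.0, 0.8, 0.81, 0.82, 0.80]
print(should_stop(epoch_loss))
```

Call it once at the end of each epoch and break out of the epoch loop when it returns True.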

I'll expand on @today's answer a bit. There is a balance to strike between how you report the loss over an epoch and how you use it to decide when training should stop.

  • If you only look at the loss of the most recent batch, you get a very noisy estimate of the loss over the whole dataset, because that batch may happen to contain all the samples your model struggles with, or all the samples that are trivial to get right.
  • If you look at the loss averaged over all batches in the epoch, you may get a skewed picture because, as you pointed out, the model has (hopefully) been improving over the epoch, so the performance on the early batches is not really comparable to the performance on the later ones.

The only way to accurately report your epoch loss is to take your model out of training mode, i.e., freeze all the model parameters, and run the model over the whole dataset. That gives an unbiased computation of the epoch loss. In general, however, it's a terrible idea: with a complex model or a lot of training data, you waste a lot of time doing it.
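To make the two bullets and this paragraph concrete, here is a toy simulation in plain NumPy (not Keras): a one-parameter linear model y ≈ w * x trained with mini-batch SGD on synthetic data for a single epoch. It records each batch's loss before that batch's update, then compares the last batch's loss, the average of the batch losses, and a full pass over the data with the weights frozen. The data, learning rate, and batch size are arbitrary assumptions; in Keras the frozen-weight pass would correspond to something like model.evaluate(...).

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic data (assumption for illustration): y = 3x + small noise.
x = rng.normal(size=100)
y = 3.0 * x + 0.1 * rng.normal(size=100)


def mse(w, xb, yb):
    """Mean squared error of the model y_hat = w * x on one set of samples."""
    return float(np.mean((w * xb - yb) ** 2))


w, lr, batch_size = 0.0, 0.1, 10
batch_losses = []
for start in range(0, len(x), batch_size):
    xb, yb = x[start:start + batch_size], y[start:start + batch_size]
    batch_losses.append(mse(w, xb, yb))             # loss before this step's update
    grad = float(np.mean(2.0 * (w * xb - yb) * xb))  # d(MSE)/dw on this batch
    w -= lr * grad                                  # one SGD step

last_batch = batch_losses[-1]               # noisy single-batch estimate
running_avg = float(np.mean(batch_losses))  # mixes early and late weights
full_pass = mse(w, x, y)                    # frozen weights, whole dataset
print(f"last batch: {last_batch:.3f}  average: {running_avg:.3f}  "
      f"full pass: {full_pass:.3f}")
```

Because the model improves quickly within the epoch, running_avg comes out well above full_pass here, which is the skew described in the second bullet.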

So I think the most common compromise is to report the loss averaged over N mini-batches, where N is large enough to smooth out the noise of individual batches but small enough that the model's performance is still comparable between the first and last of those batches.
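One way to average over N mini-batches is a sliding window that produces a smoothed value after every batch, rather than once per block. This is a variant, not something taken from the answer; a minimal sketch where WindowedLoss is a hypothetical helper and n plays the role of N:

```python
from collections import deque


class WindowedLoss:
    """Mean of the most recent n batch losses (a sliding window)."""

    def __init__(self, n):
        self.window = deque(maxlen=n)  # oldest loss drops out automatically

    def update(self, loss):
        """Add one batch loss and return the current windowed mean."""
        self.window.append(float(loss))
        return sum(self.window) / len(self.window)


tracker = WindowedLoss(n=3)
smoothed = [tracker.update(loss) for loss in [3.0, 3.0, 3.0, 0.0, 0.0, 0.0]]
print(smoothed)
```

Summing the deque costs O(n) per update; for a large N you could maintain a running sum instead.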

I know you're using Keras, but here is a PyTorch example that illustrates this concept clearly, replicated here:

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

You can see that they accumulate the loss over N = 2000 batches, report the average over those 2000 batches, then reset the running loss to zero and keep going.
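The same block-averaging pattern does not depend on PyTorch. A minimal, framework-agnostic sketch that applies it to a list of per-batch losses, such as the ones saved in the question's GAN loop; block_means is a hypothetical helper and the loss values in the usage line are made up:

```python
def block_means(batch_losses, n):
    """Average each consecutive block of n batch losses, resetting the
    running sum after every report, like running_loss in the PyTorch loop."""
    reports = []
    running = 0.0
    for i, loss in enumerate(batch_losses):
        running += loss
        if i % n == n - 1:  # end of a block of n batches
            reports.append(running / n)
            running = 0.0
    return reports  # a trailing partial block is not reported


g_losses = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
print(block_means(g_losses, n=2))
```

Like the PyTorch loop, this silently drops a final block shorter than n; if every batch should count, also report the leftover sum divided by len(batch_losses) % n when that remainder is nonzero.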
