繁体   English   中英

我如何在每个时期而不是每个批次中获得损失?

[英]How do I get a loss per epoch and not per batch?

在我的理解中,epoch 是在整个数据集上任意频繁地重复运行,然后分部分处理,即所谓的批处理。 在每次train_on_batch计算一个损失后,更新权重,下一批将获得更好的结果。 这些损失是我对神经网络质量和学习状态的指标。

在几个来源中,每个时期都会计算(并打印)损失。 因此,我不确定我这样做是否正确。

目前我的 GAN 看起来像这样:

for epoch:
  for batch:

    fakes = generator.predict_on_batch(batch)

    dlc = discriminator.train_on_batch(batch, ..)
    dlf = discriminator.train_on_batch(fakes, ..)
    dis_loss_total = 0.5 *  np.add(dlc, dlf)

    g_loss = gan.train_on_batch(batch,..)

    # save losses to array to work with later

这些损失是针对每个批次的。 我如何在一个时代获得它们? 顺便说一句:我需要一个时代的损失,为了什么?

没有直接的方法来计算一个时期的损失。 实际上,一个 epoch 的损失通常定义为该 epoch 中批次损失的平均值。 因此,您可以在一个 epoch 期间累积损失值,并在最后将其除以该 epoch 中的批次数:

epoch_loss = []
for epoch in range(n_epochs):
    acc_loss = 0.
    for batch in range(n_batches):
        # do the training 
        loss = model.train_on_batch(...)
        acc_loss += loss
    epoch_loss.append(acc_loss / n_batches)

至于另一个问题,epoch loss 的一个用途可能是将其用作停止训练的指标(但是,通常使用验证损失,而不是训练损失)。

我会稍微扩展@today 的回答。 在如何报告一个时期的损失以及如何使用它来确定何时应该停止训练之间存在一定的平衡。

  • 如果您只查看最近批次的损失,那么对数据集损失的估计将是一个非常嘈杂的估计,因为该批次可能恰好存储了您的模型遇到问题的所有样本,或者所有成功的微不足道的样本.
  • 如果您查看 epoch 中所有批次的平均损失,您可能会得到一个偏斜的响应,因为正如您所指出的,该模型已经(希望)在 epoch 上有所改进,因此初始批次的性能没有那么有意义与后面批次的性能相比。

准确报告您的 epoch 损失的唯一方法是让您的模型退出训练模式,即修复所有模型参数,并在整个数据集上运行您的模型。 这将是对 epoch 损失的无偏计算。 然而,总的来说,这是一个糟糕的主意,因为如果您有一个复杂的模型或大量的训练数据,您将浪费大量时间这样做。

因此,我认为最常见的是通过报告N个小批量的平均损失来平衡这些因素,其中N大到足以消除单个批次的噪音,但又不会太大以至于模型性能在第一个和最后一批。

我知道你在使用 Keras,但这里有一个 PyTorch 示例,它清楚地说明了这个概念,复制在这里:

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

你可以看到他们累积了N = 2000 个批次的损失,报告了这 2000 个批次的平均损失,然后将运行损失归零并继续前进。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM