繁体   English   中英

model.train() 在 PyTorch 中做了什么?

[英]What does model.train() do in PyTorch?

它是否在nn.Module中调用forward() 我想当我们调用模型时,正在使用forward方法。 为什么我们需要指定 train()?

model.train()告诉您的模型您正在训练模型。 如此有效的层,如 dropout、batchnorm 等,它们在火车上表现不同,测试程序知道发生了什么,因此可以相应地表现。

更多细节:它将模式设置为训练(参见源代码)。 您可以调用model.eval()model.train(mode=False)来告诉您正在测试。 期望train函数训练模型有点直观,但事实并非如此。 它只是设置模式。

这是module.train()的代码:

def train(self, mode=True):
        r"""Sets the module in training mode."""      
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

这是module.eval

def eval(self):
        r"""Sets the module in evaluation mode."""
        return self.train(False)

模式traineval是我们可以设置模块的仅有的两种模式,它们完全相反。

这只是一个self.training标志,目前只有DropoutBatchNorm关心该标志。

默认情况下,此标志设置为True

model.train() model.eval()
将模型设置为训练模式,即

BatchNorm层使用每批统计数据
Dropout层激活
在评估(推理)模式下设置模型,即

BatchNorm层使用运行统计数据
Dropout层停用等
等效于model.train(False)

注意:这些函数调用都不运行前向/后向传递。 他们告诉模型在运行时如何行动。

这一点很重要,因为某些模块(层) (例如DropoutBatchNorm )被设计为在训练和推理期间表现不同,因此如果在错误模式下运行,模型将产生意想不到的结果。

有两种方法可以让模型知道您的意图,即您是要训练模型还是要使用模型进行评估。 model.train()的情况下,模型知道它必须学习层,当我们使用model.eval()时,它指示模型不需要学习任何新内容,并且模型用于测试。 model.eval()也是必要的,因为在 pytorch 中,如果我们使用 batchnorm 并且在测试期间如果我们只想传递单个图像,如果未指定model.eval() ,pytorch 会抛出错误。

考虑以下模型

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GraphNet(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GraphNet, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.dropout(x, training=self.training) #Look here
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

在这里, dropout的功能在不同的操作模式下有所不同。 如您所见,它仅在self.training==True时有效。 因此,当您键入model.train()时,模型的 forward 函数将执行 dropout,否则它不会(例如在model.eval()model.train(mode=False)时)。

当前的官方文档声明如下:

这仅对某些模块有任何 [原文如此] 影响。 如果它们受到影响,请参阅特定模块的文档以了解它们在训练/评估模式下的行为的详细信息,例如 Dropout、BatchNorm 等。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM