[英]What does model.train() function do in tensorlayer Seq2seq model
[英]What does model.train() do in PyTorch?
它是否在nn.Module
中调用forward()
? 我想当我们调用模型时,正在使用forward
方法。 为什么我们需要指定 train()?
model.train()
告诉您的模型您正在训练模型。 如此有效的层,如 dropout、batchnorm 等,它们在火车上表现不同,测试程序知道发生了什么,因此可以相应地表现。
更多细节:它将模式设置为训练(参见源代码)。 您可以调用model.eval()
或model.train(mode=False)
来告诉您正在测试。 期望train
函数训练模型有点直观,但事实并非如此。 它只是设置模式。
这是module.train()
的代码:
def train(self, mode=True):
r"""Sets the module in training mode."""
self.training = mode
for module in self.children():
module.train(mode)
return self
这是module.eval
。
def eval(self):
r"""Sets the module in evaluation mode."""
return self.train(False)
模式train
和eval
是我们可以设置模块的仅有的两种模式,它们完全相反。
这只是一个self.training
标志,目前只有Dropout
和BatchNorm
关心该标志。
默认情况下,此标志设置为True
。
model.train() |
model.eval() |
---|---|
将模型设置为训练模式,即 • BatchNorm 层使用每批统计数据• Dropout 层激活等 |
在评估(推理)模式下设置模型,即 • BatchNorm 层使用运行统计数据• Dropout 层停用等 |
等效于model.train(False) 。 |
注意:这些函数调用都不运行前向/后向传递。 他们告诉模型在运行时如何行动。
这一点很重要,因为某些模块(层) (例如Dropout
、 BatchNorm
)被设计为在训练和推理期间表现不同,因此如果在错误模式下运行,模型将产生意想不到的结果。
有两种方法可以让模型知道您的意图,即您是要训练模型还是要使用模型进行评估。 在model.train()
的情况下,模型知道它必须学习层,当我们使用model.eval()
时,它指示模型不需要学习任何新内容,并且模型用于测试。 model.eval()
也是必要的,因为在 pytorch 中,如果我们使用 batchnorm 并且在测试期间如果我们只想传递单个图像,如果未指定model.eval()
,pytorch 会抛出错误。
考虑以下模型
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GraphNet(torch.nn.Module):
def __init__(self, num_node_features, num_classes):
super(GraphNet, self).__init__()
self.conv1 = GCNConv(num_node_features, 16)
self.conv2 = GCNConv(16, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.dropout(x, training=self.training) #Look here
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
在这里, dropout
的功能在不同的操作模式下有所不同。 如您所见,它仅在self.training==True
时有效。 因此,当您键入model.train()
时,模型的 forward 函数将执行 dropout,否则它不会(例如在model.eval()
或model.train(mode=False)
时)。
当前的官方文档声明如下:
这仅对某些模块有任何 [原文如此] 影响。 如果它们受到影响,请参阅特定模块的文档以了解它们在训练/评估模式下的行为的详细信息,例如 Dropout、BatchNorm 等。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.