[英]Pytorch model.train() and a separte train() function written in a tutorial
我是PyTorch的新手,我想知道您是否可以向我解释PyTorch中的默认model.train()与此处的train()函数之间的一些关键区别。
另一个train()函数在有关文本分类的官方PyTorch教程上,并且对于在训练结束时是否存储模型权重感到困惑。
https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
learning_rate = 0.005
criterion = nn.NLLLoss()
def train(category_tensor, line_tensor):
hidden = rnn.initHidden()
rnn.zero_grad()
for i in range(line_tensor.size()[0]):
output, hidden = rnn(line_tensor[i], hidden)
loss = criterion(output, category_tensor)
loss.backward()
# Add parameters' gradients to their values, multiplied by learning rate
for p in rnn.parameters():
p.data.add_(-learning_rate, p.grad.data)
return output, loss.item()
这就是功能。 然后以以下形式多次调用此函数:
n_iters = 100000
print_every = 5000
plot_every = 1000
record_every = 500
# Keep track of losses for plotting
current_loss = 0
all_losses = []
predictions = []
true_vals = []
def timeSince(since):
now = time.time()
s = now - since
m = math.floor(s / 60)
s -= m * 60
return '%dm %ds' % (m, s)
start = time.time()
for iter in range(1, n_iters + 1):
category, line, category_tensor, line_tensor = randomTrainingExample()
output, loss = train(category_tensor, line_tensor)
current_loss += loss
if iter % print_every == 0:
guess, guess_i = categoryFromOutput(output)
correct = 'O' if guess == category else 'X (%s)' % category
print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))
if iter % plot_every == 0:
all_losses.append(current_loss / plot_every)
current_loss = 0
if iter % record_every == 0:
guess, guess_i = categoryFromOutput(output)
predictions.append(guess)
true_vals.append(category)
在我看来,模型权重并没有保存或更新,而是在每次迭代时都被这样写时覆盖。 这个对吗? 还是该模型似乎正在正确训练?
另外,如果我要使用默认函数model.train(),主要的优点是什么,model.train()执行的功能与上面的train()函数大致相同吗?
根据此处的源代码, model.train()
将模块设置为训练模式。 因此,它基本上告诉您的模型您正在训练模型。 这具有仅在某些模块,如任何效果dropout
, batchnorm
等,其表现不同的训练/评估模式。 对于model.train()
,模型知道必须学习层。
您可以调用model.eval()
或model.train(mode=False)
来告诉模型它没有新知识要学习,并且该模型用于测试目的。
model.train()
只是设置模式。 它实际上并没有训练模型。
上面使用的train()
实际上是训练模型,即计算梯度并进行反向传播以学习权重。
了解更多关于model.train()
从官方pytorch论坛, 在这里 。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.