
PyTorch model.train() and a separate train() function written in a tutorial

I am new to PyTorch. Could you explain the key differences between PyTorch's built-in model.train() and the train() function shown here?

The other train() function comes from the official PyTorch tutorial on text classification, and I am confused about whether the model weights are being stored at the end of training.

https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html

import torch.nn as nn

learning_rate = 0.005

criterion = nn.NLLLoss()

def train(category_tensor, line_tensor):
    # rnn is the model instance defined earlier in the tutorial
    hidden = rnn.initHidden()
    rnn.zero_grad()
    # Feed the line one character at a time, carrying the hidden state forward
    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)
    loss = criterion(output, category_tensor)
    loss.backward()
    # Plain SGD step: subtract learning_rate * gradient from each parameter, in place
    for p in rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)
    return output, loss.item()

This is the function. It is then called repeatedly in the following loop:

import math
import time

n_iters = 100000
print_every = 5000
plot_every = 1000
record_every = 500

# Keep track of losses for plotting
current_loss = 0
all_losses = []
predictions = []
true_vals = []

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

start = time.time()

for iter in range(1, n_iters + 1):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    output, loss = train(category_tensor, line_tensor)
    current_loss += loss

    if iter % print_every == 0:
        guess, guess_i = categoryFromOutput(output)
        correct = 'O' if guess == category else 'X (%s)' % category
        print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))

    if iter % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0

    if iter % record_every == 0:
        guess, guess_i = categoryFromOutput(output)
        predictions.append(guess)
        true_vals.append(category)

To me it seems that, written like this, the model weights are not being saved or updated but are instead overwritten on each iteration. Is this correct? Or does the model appear to be training correctly?

Additionally, if I were to use the built-in model.train(), what is the main advantage, and does it do more or less the same thing as the train() function above?

According to the PyTorch source code, model.train() sets the module in training mode. It simply tells the model that it is being trained. This only affects certain modules, such as Dropout and BatchNorm, which behave differently in training and evaluation mode. After model.train(), those layers use their training-time behavior: dropout randomly zeroes activations and batch norm normalizes with per-batch statistics.
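A minimal sketch of what the mode flag does, assuming a small toy model (the nn.Sequential below is hypothetical and is not the tutorial's RNN). It only inspects the training attribute and shows that dropout is active in training mode and inactive in evaluation mode:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model, used only to illustrate the mode flag
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

model.train()  # sets training = True on the model and every submodule
train_flags = [m.training for m in model.modules()]
out_train = model(x)  # dropout zeroes some entries and rescales the rest

model.eval()  # sets training = False on the model and every submodule
eval_flags = [m.training for m in model.modules()]
out_eval_a = model(x)
out_eval_b = model(x)  # dropout is a no-op here, so repeated calls agree

linear_out = model[0](x)  # output of the linear layer alone, for comparison

print(train_flags, eval_flags)
print(torch.equal(out_eval_a, out_eval_b))
```

Neither call touches the weights; only the behavior of mode-dependent layers changes.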

You can call either model.eval() or model.train(mode=False) to switch the model to evaluation mode, which is what you want when the model is used for testing rather than learning.
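As a rough illustration, assuming a hypothetical nn.Linear stand-in for the real model, the two calls leave the module in the same state. The sketch also shows a related point that is easy to miss: evaluation mode by itself does not turn off gradient tracking, which is what torch.no_grad() is for.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 2)  # hypothetical stand-in for a real model

model.train(mode=False)
flag_after_train_false = model.training

model.train()  # back to training mode
model.eval()
flag_after_eval = model.training

x = torch.randn(5, 3)
out_eval = model(x)  # eval mode alone still records the autograd graph
with torch.no_grad():
    out_no_grad = model(x)  # no graph is recorded inside this block

print(flag_after_train_false, flag_after_eval)            # False False
print(out_eval.requires_grad, out_no_grad.requires_grad)  # True False
```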

model.train() just sets the mode. It doesn't actually train the model.

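The train() function that you are using above is what actually trains the model, i.e., it computes the loss, backpropagates to obtain the gradients, and updates the weights. Those updates are applied in place to the parameters of rnn, so they carry over from one call to the next rather than being overwritten.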
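The difference can be sketched on a toy problem. Everything below is an assumption made for illustration: a hypothetical nn.Linear model, nn.MSELoss instead of the tutorial's nn.NLLLoss, and made-up data. The update rule is the same plain SGD step as in the train() function from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(2, 1)    # hypothetical stand-in for the tutorial's rnn
criterion = nn.MSELoss()   # assumption: the tutorial itself uses nn.NLLLoss
learning_rate = 0.01

# Made-up data, only for illustration
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y = torch.tensor([[1.0], [2.0]])

def train_step(x, y):
    model.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    # Manual SGD step, applied in place to the model's own parameters
    with torch.no_grad():
        for p in model.parameters():
            p.add_(p.grad, alpha=-learning_rate)
    return loss.item()

before = [p.detach().clone() for p in model.parameters()]

model.train()  # only flips the mode flag
unchanged = all(torch.equal(b, p) for b, p in zip(before, model.parameters()))

losses = [train_step(x, y) for _ in range(50)]  # this is what changes the weights
changed = any(not torch.equal(b, p) for b, p in zip(before, model.parameters()))

print(unchanged, changed)
print(losses[0], losses[-1])
```

The updated weights live in the model object in memory for as long as the process runs. Nothing in this loop writes them to disk; for that, something like torch.save(model.state_dict(), "model.pt") would be needed, where the file name is only an example.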

You can learn more about model.train() on the official PyTorch discussion forum.

