簡體   English   中英

在教程中編寫的Pytorch model.train()和separte train()函數

[英]Pytorch model.train() and a separte train() function written in a tutorial

我是PyTorch的新手,我想知道您是否可以向我解釋PyTorch中的默認model.train()與此處的train()函數之間的一些關鍵區別。

另一個train()函數在有關文本分類的官方PyTorch教程上,並且對於在訓練結束時是否存儲模型權重感到困惑。

https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html

learning_rate = 0.005

criterion = nn.NLLLoss()

def train(category_tensor, line_tensor):
    hidden = rnn.initHidden()
    rnn.zero_grad()
    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)
    loss = criterion(output, category_tensor)
    loss.backward()
    # Add parameters' gradients to their values, multiplied by learning rate
    for p in rnn.parameters():
        p.data.add_(-learning_rate, p.grad.data)
    return output, loss.item()

這就是功能。 然后以以下形式多次調用此函數:

n_iters = 100000
print_every = 5000
plot_every = 1000
record_every = 500

# Keep track of losses for plotting
current_loss = 0
all_losses = []
predictions = []
true_vals = []

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

start = time.time()

for iter in range(1, n_iters + 1):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    output, loss = train(category_tensor, line_tensor)
    current_loss += loss

    if iter % print_every == 0:
        guess, guess_i = categoryFromOutput(output)
        correct = 'O' if guess == category else 'X (%s)' % category
        print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))

    if iter % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0

    if iter % record_every == 0:
        guess, guess_i = categoryFromOutput(output)
        predictions.append(guess)
        true_vals.append(category)

在我看來,模型權重並沒有保存或更新,而是在每次迭代時都被這樣寫時覆蓋。 這個對嗎? 還是該模型似乎正在正確訓練?

另外,如果我要使用默認函數model.train(),主要的優點是什么,model.train()執行的功能與上面的train()函數大致相同嗎?

根據此處的源代碼, model.train()將模塊設置為訓練模式。 因此,它基本上告訴您的模型您正在訓練模型。 這具有僅在某些模塊,如任何效果dropoutbatchnorm等,其表現不同的訓練/評估模式。 對於model.train() ,模型知道必須學習層。

您可以調用model.eval()model.train(mode=False)來告訴模型它沒有新知識要學習,並且該模型用於測試目的。

model.train()只是設置模式。 它實際上並沒有訓練模型。

上面使用的train()實際上是訓練模型,即計算梯度並進行反向傳播以學習權重。

了解更多關於model.train()從官方pytorch論壇, 在這里

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM