
PyTorch model.train() and a separate train() function written in a tutorial

I am new to PyTorch, and I was wondering if you could explain some of the key differences between PyTorch's built-in model.train() and the train() function defined here.

The other train() function is from the official PyTorch tutorial on text classification, and I was confused about whether the model weights are actually being stored at the end of training.

https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html

learning_rate = 0.005

criterion = nn.NLLLoss()

def train(category_tensor, line_tensor):
    hidden = rnn.initHidden()
    rnn.zero_grad()

    # Run the RNN over the line one character at a time, carrying the hidden state
    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = criterion(output, category_tensor)
    loss.backward()

    # Add parameters' gradients to their values, multiplied by learning rate
    # (a manual SGD step; newer PyTorch releases expect the alpha keyword form)
    for p in rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, loss.item()

This is the function. It is then called repeatedly, like this:

import math
import time

n_iters = 100000
print_every = 5000
plot_every = 1000
record_every = 500

# Keep track of losses for plotting
current_loss = 0
all_losses = []
predictions = []
true_vals = []

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

start = time.time()

for iter in range(1, n_iters + 1):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    output, loss = train(category_tensor, line_tensor)
    current_loss += loss

    if iter % print_every == 0:
        guess, guess_i = categoryFromOutput(output)
        correct = 'O' if guess == category else 'X (%s)' % category
        print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))

    if iter % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0

    if iter % record_every == 0:
        guess, guess_i = categoryFromOutput(output)
        predictions.append(guess)
        true_vals.append(category)

To me it seems that the model weights are not being saved or updated, but rather overwritten at each iteration when written like this. Is this correct? Or is the model actually training correctly?

Additionally, if I were to use the built-in model.train(), what is its main advantage, and does it perform more or less the same functionality as the train() function above?

As per the source code here, model.train() sets the module to training mode. It basically tells your model that you are training it. This only has an effect on certain modules, such as Dropout and BatchNorm, which behave differently in training and evaluation mode; calling model.train() makes those layers use their training-time behaviour.

You can call either model.eval() or model.train(mode=False) to tell the model that training is over and that it is now being used for evaluation/testing.

model.train() just sets the mode. It doesn't actually train the model.
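As a minimal sketch (using a made-up two-layer module, not the tutorial's RNN), you can see that model.train() / model.eval() only switch behaviour and never touch the weights:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()        # training mode: dropout randomly zeroes activations
print(net(x))      # varies from call to call

net.eval()         # evaluation mode: dropout is a no-op
print(net(x))      # deterministic
print(net(x))      # identical to the previous line

# In both modes the weights of net are exactly the same; only the layer behaviour changes.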

The train() function that you are using above is what actually trains the model, i.e., it computes the loss, backpropagates the gradients, and updates the weights in place (the loop over rnn.parameters() is a hand-written SGD step). Because those parameters live inside the rnn module, each call keeps updating the same weights; they are not discarded between iterations.
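If you want the idiomatic version of that update, the manual loop can be replaced by an optimizer. Here is a sketch that assumes the same rnn, criterion and learning_rate as in the tutorial and does the same thing as the code above:

import torch.optim as optim

optimizer = optim.SGD(rnn.parameters(), lr=learning_rate)

def train(category_tensor, line_tensor):
    hidden = rnn.initHidden()
    optimizer.zero_grad()

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = criterion(output, category_tensor)
    loss.backward()
    optimizer.step()   # same in-place SGD update as the manual parameter loop

    return output, loss.item()

Either way the updated weights stay inside rnn, and if you need them later you can persist them at the end, e.g. with torch.save(rnn.state_dict(), 'rnn.pt').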

Learn more about model.train() on the official PyTorch discussion forum, here.
