
What is the purpose of with torch.no_grad():

Consider the following code for Linear Regression implemented using PyTorch:

X is the input, Y is the output for the training set, w is the parameter that needs to be optimised
import torch

X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

w = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)

def forward(x):
    return w * x

def loss(y, y_pred):
    return ((y_pred - y)**2).mean()

print(f'Prediction before training: f(5) = {forward(5).item():.3f}')

learning_rate = 0.01
n_iters = 100

for epoch in range(n_iters):
    # predict = forward pass
    y_pred = forward(X)

    # loss
    l = loss(Y, y_pred)

    # calculate gradients = backward pass
    l.backward()

    # update weights
    #w.data = w.data - learning_rate * w.grad
    with torch.no_grad():
        w -= learning_rate * w.grad
    
    # zero the gradients after updating
    w.grad.zero_()

    if epoch % 10 == 0:
        print(f'epoch {epoch+1}: w = {w.item():.3f}, loss = {l.item():.8f}')

What does the 'with' block do? The requires_grad argument for w is already set to True. Why is it then being put under a with torch.no_grad() block?

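There is no reason to track gradients when updating the weights: the update is not part of the computation you want to differentiate. That is why the step method of PyTorch's built-in optimizers runs with gradient tracking disabled (in many releases through a @torch.no_grad() decorator).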
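As a rough illustration of the decorator form, here is a minimal sketch of an SGD-style update. The function name sgd_step and its signature are hypothetical and are not PyTorch's actual optimizer code; the sketch only shows that decorating a function with @torch.no_grad() has the same effect as wrapping its body in the with block.

```python
import torch

@torch.no_grad()  # everything inside runs with gradient tracking off
def sgd_step(params, lr):
    """Hypothetical minimal SGD update, for illustration only."""
    for p in params:
        if p.grad is not None:
            p -= lr * p.grad  # in-place update, not recorded by autograd
            p.grad.zero_()    # reset the accumulated gradient

w = torch.tensor(0.0, requires_grad=True)
loss = (w * 3.0 - 6.0) ** 2
loss.backward()             # d(loss)/dw = 2 * (3w - 6) * 3 = -36 at w = 0
sgd_step([w], lr=0.01)      # w becomes 0 - 0.01 * (-36) = 0.36

print(w.requires_grad)      # True: w is still a trainable leaf tensor
```

The decorator only affects what happens inside the function, so the forward pass and backward() call outside it are tracked as usual.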

"With torch.no_grad" block means doing these lines without keeping track of the gradients.

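The requires_grad argument tells PyTorch to track operations on w so that l.backward() can compute w.grad. The with torch.no_grad() block temporarily switches that tracking off. It is used here (as in most training loops) because the weight update is not part of the model's computation and should not be recorded in the graph that the next backward pass differentiates. In addition, PyTorch does not allow an in-place operation such as w -= ... on a leaf tensor that requires grad while tracking is on; it raises a RuntimeError. Inside the block the update is applied directly, and w.requires_grad stays True for the next iteration.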
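A small sketch of what the block changes in practice, using made-up numbers rather than the training data from the question: results computed inside the block are detached from autograd, and the same in-place update that succeeds inside the block is rejected outside it.

```python
import torch

w = torch.tensor(0.0, requires_grad=True)

# Inside the block, results are not attached to the autograd graph.
with torch.no_grad():
    y = w * 2.0
print(y.requires_grad)   # False

# Outside the block, a normal forward/backward pass still works.
loss = (w * 2.0 - 4.0) ** 2
loss.backward()          # w.grad is 2 * (2w - 4) * 2 = -16 at w = 0

# In-place update with tracking on: PyTorch refuses to modify the leaf.
try:
    w -= 0.01 * w.grad
    raised = False
except RuntimeError as err:
    raised = True
    print("update outside no_grad failed:", err)

# The same update inside the block is applied directly.
with torch.no_grad():
    w -= 0.01 * w.grad   # w becomes 0 - 0.01 * (-16) = 0.16

print(w.item(), w.requires_grad, w.is_leaf)
```

The commented-out w.data assignment in the question is an older way of getting the same effect; the no_grad block is the form the question's code actually uses.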
