
Is it necessary to use with torch.no_grad() for feature extraction?

I'm attempting feature extraction in an unorthodox way. I extract features in eval() mode, so that the dropout layers are switched off and the batch norm layers use their stored running means and standard deviations (computed on ImageNet) instead of per-batch statistics.

I use a feature extractor to extract features from two related images, concatenate the two feature tensors along the feature dimension, and pass the result through a linear (dense) classifier that I train. I'm wondering whether I can avoid using with torch.no_grad(), since the two models are unrelated.
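A minimal sketch of the concatenation step, assuming (hypothetically) that the extractor returns pooled feature maps of shape (batch, 2208, 1, 1) per image, so that two of them give the 4416 features used below. 2208 is, for example, the feature width of torchvision's DenseNet-161, but the question does not say which model DenseNetConv wraps.

```python
import torch

# Hypothetical shapes: a batch of 4 image pairs, each image mapped to a
# pooled 2208-channel feature map with spatial size 1x1.
batch_size, channels = 4, 2208
features_1 = torch.randn(batch_size, channels, 1, 1)
features_2 = torch.randn(batch_size, channels, 1, 1)

# Concatenate along the channel (feature) dimension, not the batch dimension.
combined = torch.cat([features_1, features_2], dim=1)  # (4, 4416, 1, 1)

# Flatten everything after the batch dimension to get (batch, 4416).
combined = torch.flatten(combined, start_dim=1)
print(combined.shape)  # torch.Size([4, 4416])
```

If the extractor instead returned un-pooled maps (say 7x7), view(-1, 4416) would silently produce 49 times as many rows as there are samples, whereas flatten(start_dim=1) keeps the batch size and surfaces the problem as a wrong feature width.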

Here is a simplified version:

import torch
import torch.nn as nn

# device, DenseNetConv and dataloaders_dict are defined elsewhere in my code
num_classes = 2
num_epochs = 10

densenet = DenseNetConv()

# set densenet to eval: dropout off, batch norm uses its ImageNet running means/std devs
densenet.eval()
densenet.to(device)

classifier = nn.Linear(4416, num_classes)
classifier.to(device)

# the optimizer is created after the classifier and only updates the classifier's parameters
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

for epoch in range(num_epochs):
    classifier.train()

    for i, (inputs_1, inputs_2, labels) in enumerate(dataloaders_dict['train']):
        inputs_1 = inputs_1.to(device)
        inputs_2 = inputs_2.to(device)
        labels = labels.to(device)

        features_1 = densenet(inputs_1)  # extract features 1
        features_2 = densenet(inputs_2)  # extract features 2

        combined = torch.cat([features_1, features_2], dim=1)  # combine features
        combined = combined.view(-1, 4416)  # reshape to (batch, 4416)

        optimizer.zero_grad()

        # Forward pass to get output/logits
        outputs = classifier(combined)

        # Calculate loss: softmax --> cross-entropy loss
        loss = criterion(outputs, labels)

        _, pred = torch.max(outputs, 1)
        equality_check = (labels == pred)

        # Getting gradients w.r.t. parameters
        loss.backward()
        optimizer.step()

As you can see, I do not call with torch.no_grad(), even though densenet is a separate feature extractor in eval() mode. Is there an issue with the way this is implemented, or can I assume that this will not interfere with the classifier model?

If you are doing inference on a model, applying torch.no_grad() won't have any effect on the resulting output. As you've said, only nn.Module.eval will, since it modifies how the forward pass is performed (namely, which statistics batch norm uses to normalize the batch elements, and whether dropout is applied).
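A small check of this claim, using a hypothetical stand-in model with batch norm and dropout rather than the question's DenseNetConv: in eval mode the outputs with and without torch.no_grad() match, and only the autograd bookkeeping differs, while switching back to train mode changes the numbers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a feature extractor with batch norm and dropout.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.BatchNorm1d(16),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(16, 4),
)
x = torch.randn(10, 8)

model.eval()  # batch norm uses running statistics, dropout is disabled
out_with_grad = model(x)
with torch.no_grad():
    out_no_grad = model(x)

# Same values; only whether a graph was recorded differs.
print(torch.allclose(out_with_grad, out_no_grad))  # True
print(out_with_grad.requires_grad, out_no_grad.requires_grad)  # True False

model.train()  # batch statistics and active dropout
with torch.no_grad():
    out_train = model(x)
print(torch.allclose(out_train, out_no_grad))
```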

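It is recommended to switch off gradient computation when backpropagation is not necessary. This avoids storing the intermediate activations needed for the backward pass, which reduces memory use and speeds up inference. In your loop, leaving it on does not break the classifier: loss.backward() propagates all the way through densenet and accumulates gradients in its parameters' .grad fields, but since your optimizer only holds classifier.parameters(), densenet's weights are never updated. That backward pass through densenet is simply wasted work.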
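The effect on the question's training loop can be checked directly. This sketch uses hypothetical small modules (extractor, classifier) in place of densenet and the 4416-input classifier: without torch.no_grad() the backward pass reaches the extractor and fills its .grad fields, with it no graph is built through the extractor, and in both cases an optimizer built only from the classifier's parameters leaves the extractor's weights unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small stand-ins for the question's densenet and classifier.
extractor = nn.Sequential(nn.Linear(8, 6), nn.BatchNorm1d(6), nn.ReLU())
classifier = nn.Linear(12, 2)
extractor.eval()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

x1, x2 = torch.randn(5, 8), torch.randn(5, 8)
labels = torch.randint(0, 2, (5,))

def train_step(use_no_grad):
    """One optimisation step; returns whether the extractor received gradients."""
    extractor.zero_grad(set_to_none=True)  # start with no stored gradients
    if use_no_grad:
        with torch.no_grad():
            f1, f2 = extractor(x1), extractor(x2)
    else:
        f1, f2 = extractor(x1), extractor(x2)
    combined = torch.cat([f1, f2], dim=1)
    optimizer.zero_grad()
    loss = criterion(classifier(combined), labels)
    loss.backward()
    optimizer.step()  # only touches the classifier's parameters
    return extractor[0].weight.grad is not None

before = extractor[0].weight.detach().clone()
had_grad_without = train_step(use_no_grad=False)
had_grad_with = train_step(use_no_grad=True)
unchanged = torch.equal(before, extractor[0].weight)
print(had_grad_without)  # True: backward reached the extractor
print(had_grad_with)     # False: no graph was recorded through it
print(unchanged)         # True: the extractor's weights were never updated
```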

  • In your case, you can either wrap your inference calls on densenet in a torch.no_grad() context:

     with torch.no_grad():
         features_1 = densenet(inputs_1)  # extract features 1
         features_2 = densenet(inputs_2)  # extract features 2
  • Or alternatively, switch off the requires_grad flag on your module's parameter tensors using nn.Module.requires_grad_ :

     densenet.eval()
     densenet.requires_grad_(False)
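A minimal sketch of the second option, again with hypothetical stand-in modules rather than DenseNetConv: once the extractor's parameters are frozen with requires_grad_(False), its outputs do not require gradients, backward leaves the extractor's .grad fields empty, and the classifier still receives gradients as usual.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins; in the question these are DenseNetConv() and nn.Linear(4416, 2).
extractor = nn.Sequential(nn.Linear(8, 6), nn.ReLU(), nn.Dropout(0.2))
classifier = nn.Linear(12, 2)

# Eval mode plus frozen parameters.
extractor.eval()
extractor.requires_grad_(False)

x1, x2 = torch.randn(5, 8), torch.randn(5, 8)
labels = torch.randint(0, 2, (5,))

f1, f2 = extractor(x1), extractor(x2)
print(f1.requires_grad)  # False: neither inputs nor parameters need gradients

logits = classifier(torch.cat([f1, f2], dim=1))
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()

print(all(p.grad is None for p in extractor.parameters()))       # True
print(all(p.grad is not None for p in classifier.parameters()))  # True
```

One design difference between the two options: requires_grad_(False) stays in effect for every later call to the module, whereas torch.no_grad() only applies to the code inside the with block.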
