
Can't fix: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

Currently, I'm trying to replicate the DeblurGAN-v2 network, and at the moment I'm working on the training. Here is the current state of my training pipeline:

from torch.autograd import Variable

torch.autograd.set_detect_anomaly(mode=True)
total_generator_loss = 0
total_discriminator_loss = 0
psnr_score = 0.0
used_loss_function = 'wgan_gp_loss'
for epoch in range(n_epochs):

    # set to train mode
    generator.train(); discriminator.train()
    tqdm_bar = tqdm(train_loader, desc=f'Training Epoch {epoch} ', total=int(len(train_loader)))
    for batch_idx, imgs in enumerate(tqdm_bar):

        # load images to the GPU
        blurred_images = imgs["blurred"].cuda()
        sharped_images = imgs["sharp"].cuda()

        # generator output
        deblurred_img = generator(blurred_images)

        # denormalize
        with torch.no_grad():
            denormalized_blurred = denormalize(blurred_images)
            denormalized_sharp = denormalize(sharped_images)
            denormalized_deblurred = denormalize(deblurred_img)

        # get D's output
        sharp_discriminator_out = discriminator(sharped_images)
        deblurred_discriminator_out = discriminator(deblurred_img)

        # set critic_updates
        if used_loss_function == 'wgan_gp_loss':
            critic_updates = 5
        else:
            critic_updates = 1

        # train discriminator
        discriminator_loss = 0
        for i in range(critic_updates):
            discriminator_optimizer.zero_grad()
            # train discriminator on real and fake
            if used_loss_function == 'wgan_gp_loss':
                gp_lambda = 10
                alpha = random.random()
                interpolates = alpha * sharped_images + (1 - alpha) * deblurred_img
                interpolates_discriminator_out = discriminator(interpolates)
                kwargs = {'gp_lambda': gp_lambda,
                          'interpolates': interpolates,
                          'interpolates_discriminator_out': interpolates_discriminator_out,
                          'sharp_discriminator_out': sharp_discriminator_out,
                          'deblurred_discriminator_out': deblurred_discriminator_out
                          }
                wgan_loss_d, gp_d = wgan_gp_loss('D', **kwargs)
                discriminator_loss_per_update = wgan_loss_d + gp_d

            discriminator_loss_per_update.backward(retain_graph=True)
            discriminator_optimizer.step()
            discriminator_loss += discriminator_loss_per_update.item()

But when I run this code, I receive the following error message:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 4, 4]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

RuntimeError                              Traceback (most recent call last)
<ipython-input> in <module>()
     62 #         discriminator_loss_per_update = gan_loss_d
     63
---> 64         discriminator_loss_per_update.backward(retain_graph=True)
     65         discriminator_optimizer.step()
     66         discriminator_loss += discriminator_loss_per_update.item()

/usr/local/lib/python3.7/dist-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    243                 create_graph=create_graph,
    244                 inputs=inputs)
--> 245         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    246
    247     def register_hook(self, hook):

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145     Variable._execution_engine.run_backward(
    146         tensors, grad_tensors, retain_graph, create_graph, inputs,
--> 147         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    148
    149

Unfortunately, I can't trace the in-place operation that causes this error. Does anybody have an idea or some advice for me? I would appreciate any input.
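For anyone hitting the same message: it can be reproduced in a few lines, assuming the cause is the same as here, namely calling backward() a second time through a retained graph after optimizer.step() has already modified the weights in-place (a toy two-layer model, not the asker's network):

```python
import torch

# Two stacked Linears, so that the backward pass for the first layer
# needs the *saved values* of the second layer's weight.
torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))
opt = torch.optim.SGD(net.parameters(), lr=0.1)

loss = net(torch.randn(1, 2)).sum()
loss.backward(retain_graph=True)   # first backward: fine
opt.step()                         # in-place parameter update bumps the version counters

try:
    loss.backward()                # second backward through the now-stale graph
except RuntimeError as e:
    error_message = str(e)         # "... has been modified by an inplace operation ..."
```

This is exactly the pattern in the critic loop above: `sharp_discriminator_out` and `deblurred_discriminator_out` are computed once before the loop, so from the second critic update on, backward() runs through a graph whose discriminator weights `optimizer.step()` has already overwritten.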

Try changing the last line to:

discriminator_loss = discriminator_loss + discriminator_loss_per_update.item()
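If that alone doesn't resolve it, a common fix for this error in WGAN-GP training loops is to recompute the discriminator's outputs inside the critic loop (so each backward() runs through a graph built from the current weights) and to detach the generator output during the discriminator step. A self-contained sketch with a toy discriminator and made-up image shapes, not the asker's exact model:

```python
import torch

# Toy stand-ins: a tiny critic D on 1x8x8 "images"; `fake` plays the role
# of generator(blurred_images).detach().
torch.manual_seed(0)
D = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(64, 1))
opt_d = torch.optim.Adam(D.parameters(), lr=1e-4)

sharp = torch.randn(4, 1, 8, 8)
fake = torch.randn(4, 1, 8, 8)

for _ in range(5):                       # critic_updates
    opt_d.zero_grad()
    # fresh forward passes each update: the graph always matches current weights
    d_sharp = D(sharp)
    d_fake = D(fake)

    # gradient penalty on random interpolates (per-sample alpha)
    alpha = torch.rand(4, 1, 1, 1)
    interp = (alpha * sharp + (1 - alpha) * fake).requires_grad_(True)
    d_interp = D(interp)
    grads = torch.autograd.grad(d_interp.sum(), interp, create_graph=True)[0]
    gp = ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()

    loss_d = d_fake.mean() - d_sharp.mean() + 10.0 * gp
    loss_d.backward()                    # no retain_graph needed any more
    opt_d.step()
```

Because nothing from a previous iteration's graph is reused, `retain_graph=True` can be dropped, which is usually what makes the "modified by an inplace operation" error disappear.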
