
PyTorch / device problem (CPU, GPU) when loading the state dict for an optimizer

Hi, I'm a student who has been studying PyTorch since last summer.

# load the saved checkpoint
state = torch.load('drive/My Drive/MODEL/4 CBAM classifier55')

# restore the model weights
model = MyResNet()
model.load_state_dict(state['state_dict'])

criterion = nn.CrossEntropyLoss()

# restore the optimizer state
optimizer = optim.Adam(model.parameters(), lr=0.0003, betas=(0.5, 0.999))
optimizer.load_state_dict(state['optimizer'])

# move the model to the target device
model.to(device)

I wrote the code above.

RuntimeError                              Traceback (most recent call last)
<ipython-input-26-507493db387a> in <module>()
     56     new_loss.backward()
     57 
---> 58     optimizer.step()
     59 
     60     running_loss += loss.item()

/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     13         def decorate_context(*args, **kwargs):
     14             with self:
---> 15                 return func(*args, **kwargs)
     16         return decorate_context
     17 

/usr/local/lib/python3.6/dist-packages/torch/optim/adam.py in step(self, closure)
     97 
     98                 # Decay the first and second moment running average coefficient
---> 99                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    100                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    101                 if amsgrad:

RuntimeError: expected device cpu but got device cuda:0

When I run the training code, I get this error. When I comment out 'optimizer.load_state_dict', it works fine. How can I solve this problem? Thank you for your answer. :)

It seems the optimizer state was on cuda when you saved it and you are now trying to use it on cpu, or vice versa. To avoid this error, a simple way is to pass the map_location argument to torch.load.

Just pass map_location=<device you want to use> in torch.load and it should work fine. Also see https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices
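
For example, a rough sketch reusing the variable names from the question (assuming device is the GPU you want to train on; the model is moved to the device before the optimizer is created, so the restored optimizer state ends up on the same device as the parameters):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map every tensor in the checkpoint straight onto the target device
state = torch.load('drive/My Drive/MODEL/4 CBAM classifier55', map_location=device)

model = MyResNet()
model.load_state_dict(state['state_dict'])
model.to(device)  # move the parameters before building the optimizer

optimizer = optim.Adam(model.parameters(), lr=0.0003, betas=(0.5, 0.999))
optimizer.load_state_dict(state['optimizer'])  # optimizer state now matches the parameters' device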
