Expected object of device type cuda but got device type cpu in Pytorch
Pytorch CPU CUDA device load without gpu
I found this nice PyTorch MobileNet code that I cannot run on the CPU: https://github.com/rdroste/unisal
I am new to PyTorch, so I don't know what to do.
In line 174 of the module train.py, the device is set:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
As far as I can tell, that is correct PyTorch.
Do I also have to change torch.load? I tried that without success.
class BaseModel(nn.Module):
    """Abstract model class with functionality to save and load weights"""

    def forward(self, *input):
        raise NotImplementedError

    def save_weights(self, directory, name):
        torch.save(self.state_dict(), directory / f'weights_{name}.pth')

    def load_weights(self, directory, name):
        self.load_state_dict(torch.load(directory / f'weights_{name}.pth'))

    def load_best_weights(self, directory):
        self.load_state_dict(torch.load(directory / f'weights_best.pth'))

    def load_epoch_checkpoint(self, directory, epoch):
        """Load state_dict from a Trainer checkpoint at a specific epoch"""
        chkpnt = torch.load(directory / f"chkpnt_epoch{epoch:04d}.pth")
        self.load_state_dict(chkpnt['model_state_dict'])

    def load_checkpoint(self, file):
        """Load state_dict from a specific Trainer checkpoint"""
        chkpnt = torch.load(file)
        self.load_state_dict(chkpnt['model_state_dict'])

    def load_last_chkpnt(self, directory):
        """Load state_dict from the last Trainer checkpoint"""
        last_chkpnt = sorted(list(directory.glob('chkpnt_epoch*.pth')))[-1]
        self.load_checkpoint(last_chkpnt)
I don't understand it. Where do I have to tell PyTorch that there is no GPU?
Full error:
Traceback (most recent call last):
File "run.py", line 99, in <module>
fire.Fire()
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/fire/core.py", line 138, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/fire/core.py", line 471, in _Fire
target=component.__name__)
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/fire/core.py", line 675, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "run.py", line 95, in predict_examples
example_folder, is_video, train_id=train_id, source=source)
File "run.py", line 72, in predictions_from_folder
folder_path, is_video, source=source, model_domain=model_domain)
File "/home/b256/Data/saliency_models/unisal-master/unisal/train.py", line 871, in generate_predictions_from_path
self.model.load_best_weights(self.train_dir)
File "/home/b256/Data/saliency_models/unisal-master/unisal/train.py", line 1057, in model
self._model = model_cls(**self.model_cfg)
File "/home/b256/Data/saliency_models/unisal-master/unisal/model.py", line 190, in __init__
self.cnn = MobileNetV2(**self.cnn_cfg)
File "/home/b256/Data/saliency_models/unisal-master/unisal/models/MobileNetV2.py", line 156, in __init__
Path(__file__).resolve().parent / 'weights/mobilenet_v2.pth.tar')
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 367, in load
return _load(f, map_location, pickle_module)
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 538, in _load
result = unpickler.load()
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 504, in persistent_load
data_type(size), location)
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 113, in default_restore_location
result = fn(storage, location)
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 94, in _cuda_deserialize
device = validate_cuda_device(location)
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 78, in validate_cuda_device
raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.
At https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-on-gpu-load-on-cpu you will see that there is a map_location keyword argument that sends the weights to the proper device:
model.load_state_dict(torch.load(PATH, map_location=device))
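Applied to the BaseModel class from the question, the fix is to pass map_location into each torch.load call. A minimal sketch under the assumption that this is the only change needed (TinyModel is a hypothetical stand-in for the unisal model, not its real architecture):

```python
import pathlib
import tempfile

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for unisal's BaseModel subclasses."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def save_weights(self, path):
        torch.save(self.state_dict(), path)

    def load_weights(self, path):
        # Pick the same device string as unisal's train.py does, and
        # hand it to torch.load: a checkpoint saved on a GPU machine is
        # then remapped to the CPU instead of raising the RuntimeError.
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.load_state_dict(torch.load(path, map_location=device))


path = pathlib.Path(tempfile.mkdtemp()) / 'weights_tiny.pth'
m = TinyModel()
m.save_weights(path)

m2 = TinyModel()
m2.load_weights(path)  # works on a CPU-only machine
```

The same one-line change applies to load_best_weights, load_epoch_checkpoint, load_checkpoint, and, in the traceback, to the torch.load call inside MobileNetV2.py that loads mobilenet_v2.pth.tar.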
From the docs, https://pytorch.org/docs/stable/generated/torch.load.html#torch.load :
torch.load() uses Python's unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the runtime system doesn't have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the map_location argument.
If map_location is a callable, it will be called once for each serialized storage with two arguments: storage and location. The storage argument will be the initial deserialization of the storage, residing on the CPU. Each serialized storage has a location tag associated with it which identifies the device it was saved from, and this tag is the second argument passed to map_location. The builtin location tags are 'cpu' for CPU tensors and 'cuda:device_id' (e.g. 'cuda:2') for CUDA tensors. map_location should return either None or a storage. If map_location returns a storage, it will be used as the final deserialized object, already moved to the right device. Otherwise, torch.load() will fall back to the default behavior, as if map_location wasn't specified.
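The callable form can be sketched like this; the to_cpu function below is an illustrative example, not part of the PyTorch API:

```python
import pathlib
import tempfile

import torch

path = pathlib.Path(tempfile.mkdtemp()) / 'tensor.pth'
t = torch.ones(3)
torch.save(t, path)


def to_cpu(storage, location):
    # Called once per serialized storage. `storage` is already
    # deserialized on the CPU; `location` is the tag of the device the
    # tensor was saved from ('cpu' here, 'cuda:0' for a GPU checkpoint).
    # Returning the storage keeps everything on the CPU.
    return storage


loaded = torch.load(path, map_location=to_cpu)
```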
If map_location is a torch.device object or a string containing a device tag, it indicates the location where all tensors should be loaded.
Otherwise, if map_location is a dict, it will be used to remap location tags appearing in the file (keys) to ones that specify where to put the storages (values).
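The string form and the dict form side by side, in a small sketch. The tensor here is saved from the CPU, so the 'cuda:0' → 'cpu' remapping is a no-op, but it is the typical dict entry for a checkpoint written on a GPU machine:

```python
import pathlib
import tempfile

import torch

path = pathlib.Path(tempfile.mkdtemp()) / 'saved.pth'
t = torch.arange(4.0)
torch.save(t, path)

# String / device form: load every tensor onto the CPU.
a = torch.load(path, map_location='cpu')

# Dict form: remap location tags found in the file (keys) to
# destination devices (values).
b = torch.load(path, map_location={'cuda:0': 'cpu'})
```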