
Pytorch CPU CUDA device load without gpu

I found this nice PyTorch MobileNet code, but I cannot get it to run on the CPU: https://github.com/rdroste/unisal

I am new to PyTorch, so I don't know what to do.

In line 174 of the module train.py the device is set:

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

As far as I know, this is the correct PyTorch idiom.

Do I also have to change torch.load? I tried, without success.

class BaseModel(nn.Module):
    """Abstract model class with functionality to save and load weights"""

    def forward(self, *input):
        raise NotImplementedError

    def save_weights(self, directory, name):
        torch.save(self.state_dict(), directory / f'weights_{name}.pth')

    def load_weights(self, directory, name):
        self.load_state_dict(torch.load(directory / f'weights_{name}.pth'))

    def load_best_weights(self, directory):
        self.load_state_dict(torch.load(directory / f'weights_best.pth'))

    def load_epoch_checkpoint(self, directory, epoch):
        """Load state_dict from a Trainer checkpoint at a specific epoch"""
        chkpnt = torch.load(directory / f"chkpnt_epoch{epoch:04d}.pth")
        self.load_state_dict(chkpnt['model_state_dict'])

    def load_checkpoint(self, file):
        """Load state_dict from a specific Trainer checkpoint"""
        """Load """
        chkpnt = torch.load(file)
        self.load_state_dict(chkpnt['model_state_dict'])

    def load_last_chkpnt(self, directory):
        """Load state_dict from the last Trainer checkpoint"""
        last_chkpnt = sorted(list(directory.glob('chkpnt_epoch*.pth')))[-1]
        self.load_checkpoint(last_chkpnt)

I don't understand it. Where do I have to tell PyTorch that there is no GPU?

Full error:

Traceback (most recent call last):
  File "run.py", line 99, in <module>
    fire.Fire()

  File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/fire/core.py", line 138, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)

  File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/fire/core.py", line 471, in _Fire
    target=component.__name__)

  File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/fire/core.py", line 675, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)

  File "run.py", line 95, in predict_examples
    example_folder, is_video, train_id=train_id, source=source)

  File "run.py", line 72, in predictions_from_folder
    folder_path, is_video, source=source, model_domain=model_domain)

  File "/home/b256/Data/saliency_models/unisal-master/unisal/train.py", line 871, in generate_predictions_from_path
    self.model.load_best_weights(self.train_dir)

  File "/home/b256/Data/saliency_models/unisal-master/unisal/train.py", line 1057, in model
    self._model = model_cls(**self.model_cfg)

  File "/home/b256/Data/saliency_models/unisal-master/unisal/model.py", line 190, in __init__
    self.cnn = MobileNetV2(**self.cnn_cfg)

  File "/home/b256/Data/saliency_models/unisal-master/unisal/models/MobileNetV2.py", line 156, in __init__
    Path(__file__).resolve().parent / 'weights/mobilenet_v2.pth.tar')

  File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 367, in load
    return _load(f, map_location, pickle_module)

  File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 538, in _load
    result = unpickler.load()

  File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 504, in persistent_load
    data_type(size), location)

  File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 113, in default_restore_location
    result = fn(storage, location)

  File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 94, in _cuda_deserialize
    device = validate_cuda_device(location)

  File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 78, in validate_cuda_device
    raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.

In https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-on-gpu-load-on-cpu you will see that there is a map_location keyword argument to send the weights to the proper device:

model.load_state_dict(torch.load(PATH, map_location=device))
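
A minimal sketch, assuming a CPU-only machine, of how this keyword could be applied to the load methods of the BaseModel class from the question (note that in the traceback the failing torch.load call is actually inside unisal/models/MobileNetV2.py, which would need the same change):

import torch
import torch.nn as nn


class BaseModel(nn.Module):
    """Abstract model class with functionality to save and load weights"""

    def load_weights(self, directory, name):
        # map_location='cpu' remaps storages that were saved from 'cuda:0'
        # onto the CPU, so loading works without a GPU.
        self.load_state_dict(
            torch.load(directory / f'weights_{name}.pth', map_location='cpu'))

    def load_best_weights(self, directory):
        self.load_state_dict(
            torch.load(directory / f'weights_best.pth', map_location='cpu'))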

From the docs at https://pytorch.org/docs/stable/generated/torch.load.html#torch.load :

torch.load() uses Python's unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the run time system doesn't have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the map_location argument.

If map_location is a callable, it will be called once for each serialized storage with two arguments: storage and location. The storage argument will be the initial deserialization of the storage, residing on the CPU. Each serialized storage has a location tag associated with it which identifies the device it was saved from, and this tag is the second argument passed to map_location. The builtin location tags are 'cpu' for CPU tensors and 'cuda:device_id' (e.g. 'cuda:2') for CUDA tensors. map_location should return either None or a storage. If map_location returns a storage, it will be used as the final deserialized object, already moved to the right device. Otherwise, torch.load() will fall back to the default behavior, as if map_location wasn't specified.
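
As an illustration of the callable form, a function that simply returns the CPU-deserialized storage keeps every tensor on the CPU, whatever location tag it was saved with (the file name is the one from the question's code and is only an example):

import torch

def to_cpu(storage, location):
    # 'storage' is already deserialized on the CPU; returning it unchanged
    # keeps it there instead of restoring it to the saved 'cuda:...' device.
    return storage

chkpnt = torch.load('weights_best.pth', map_location=to_cpu)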

If map_location is a torch.device object or a string containing a device tag, it indicates the location where all tensors should be loaded.

Otherwise, if map_location is a dict, it will be used to remap location tags appearing in the file (keys) to ones that specify where to put the storages (values).
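
For the error above, where the checkpoint's storages are tagged 'cuda:0', the string, torch.device and dict forms are roughly equivalent ways of loading onto the CPU (file name again taken from the question's code):

import torch

# String form: load all tensors onto the CPU.
chkpnt = torch.load('weights_best.pth', map_location='cpu')

# torch.device form, equivalent to the string above.
chkpnt = torch.load('weights_best.pth', map_location=torch.device('cpu'))

# Dict form: remap storages tagged 'cuda:0' in the file to 'cpu'.
chkpnt = torch.load('weights_best.pth', map_location={'cuda:0': 'cpu'})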
