[英]torchvision.datasets.mnist RunTimeError on JupyterLab
我正在嘗試在 JupyterLab 上運行以下示例代碼(通過 GCP vertex AI):
import torch
from torchvision import transforms
from torchvision import datasets
train_data = datasets.MNIST(root='data', train=True, download=True, transform=None)
print(train_data)
版本:torch-1.12.1+cu113 torchvision-0.13.1+cu113
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_10081/229378695.py in <module>
11 from torchvision import datasets
12
---> 13 train_data = datasets.MNIST(root='data', train=True, download=True, transform=None)
14 print(train_data)
/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py in __init__(self, root, train, transform, target_transform, download)
102 raise RuntimeError("Dataset not found. You can use download=True to download it")
103
--> 104 self.data, self.targets = self._load_data()
105
106 def _check_legacy_exist(self):
/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py in _load_data(self)
121 def _load_data(self):
122 image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
--> 123 data = read_image_file(os.path.join(self.raw_folder, image_file))
124
125 label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py in read_image_file(path)
542
543 def read_image_file(path: str) -> torch.Tensor:
--> 544 x = read_sn3_pascalvincent_tensor(path, strict=False)
545 if x.dtype != torch.uint8:
546 raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py in read_sn3_pascalvincent_tensor(path, strict)
529
530 assert parsed.shape[0] == np.prod(s) or not strict
--> 531 return parsed.view(*s)
532
533
RuntimeError: shape '[60000, 28, 28]' is invalid for input of size 9437168
____________________
我在嘗試加載 MNIST 時遇到了這個奇怪的錯誤
此錯誤通常是由下載到您系統上的 MNIST 數據集文件問題引起的。 嘗試刪除data
目錄中的 MNIST 數據集文件,然后再次運行代碼以下載數據集文件的新副本。 請遵循以下代碼:
import os
import shutil
mnist_folder = 'data/MNIST'
if os.path.exists(mnist_folder):
shutil.rmtree(mnist_folder)
train_data = datasets.MNIST(root='data', train=True, download=True, transform=None)
如果此方法不起作用,請訪問此網站並將它們放在data/MNIST
文件夾中。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.