![](/img/trans.png)
[英]Error with _DataLoaderIter in torch.utils.data.dataloader
[英]Pytorch: Batch size is missing in data after torch.utils.random_split() is used on dataloader.dataset
我使用 random_split() 將我的數據分成訓練和測試,我觀察到如果在創建數據加載器后進行隨機拆分,從數據加載器獲取一批數據時會丟失批量大小。
import torch
from torchvision import transforms, datasets
from torch.utils.data import random_split
# Normalize the data
transform_image = transforms.Compose([
transforms.Resize((240, 320)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
data = '/data/imgs/train'
def load_dataset():
data_path = data
main_dataset = datasets.ImageFolder(
root = data_path,
transform = transform_image
)
loader = torch.utils.data.DataLoader(
dataset = main_dataset,
batch_size= 64,
num_workers = 0,
shuffle= True
)
# Dataset has 22424 data points
trainloader, testloader = random_split(loader.dataset, [21000, 1424])
return trainloader, testloader
trainloader, testloader = load_dataset()
現在從訓練和測試加載器中獲取一批圖像:
images, labels = next(iter(trainloader))
images.shape
# %%
len(trainloader)
# %%
images_test, labels_test = next(iter(testloader))
images_test.shape
# %%
len(testloader)
我得到的輸出沒有訓練或測試批次的批次大小。 輸出暗淡應該是 [batch x channel x H x W] 但我得到 [channel x H x W]。
輸出:
但是,如果我從數據集創建拆分,然后使用拆分創建兩個數據加載器,我會在輸出中獲得批量大小。
def load_dataset():
data_path = data
main_dataset = datasets.ImageFolder(
root = data_path,
transform = transform_image
)
# Dataset has 22424 data points
train_data, test_data = random_split(main_dataset, [21000, 1424])
trainloader = torch.utils.data.DataLoader(
dataset = train_data,
batch_size= 64,
num_workers = 0,
shuffle= True
)
testloader = torch.utils.data.DataLoader(
dataset = test_data,
batch_size= 64,
num_workers= 0,
shuffle= True
)
return trainloader, testloader
trainloader, testloader = load_dataset()
在運行相同的 4 個命令以獲取單個訓練和測試批次時:
第一種方法錯了嗎? 雖然長度顯示數據已經被拆分了。 那么為什么我看不到批量大小呢?
第一種方法是錯誤的。
只有DataLoader
實例返回批次的項目。 像實例一樣的Dataset
沒有。
當您調用make_split
您將傳遞給它loader.dataset
,它只是對main_dataset
(不是DataLoader
)的引用。 結果是trainloader
和testloader
是Dataset
s 而不是DataLoader
s。 實際上,當您從load_dataset
返回時,您會丟棄loader
,它是您唯一的DataLoader
。
第二個版本是你應該怎么做才能獲得兩個單獨的DataLoader
。
您正在將數據集拆分為兩個。 這將為您提供 2 個數據集,在迭代時將返回形狀channel, height, width
(即3,h,w
單個圖像張量,並且默認情況下不會為您提供圍繞這些數據集的 Dataloader。
您接下來所做的實際上是下一步正確的操作,即圍繞每個數據集創建一個 Dataloader。 您在 Dataloader 中定義批次大小,現在迭代 Dataloader 將返回形狀為batch_size, channel, height, width
張量。
即使您打算提供尺寸為 1 的模型批次,您也必須在張量中具有批次尺寸維度。 為此,您可以使用batchsize=1
的 Dataloader 或在開始時為圖像X
或X.unsqueeze(0)
使用torch.unsqueeze(X, 0)
添加一個虛擬維度,使張量的形狀為1,3,h,w
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.