簡體   English   中英

Pytorch:在 dataloader.dataset 上使用 torch.utils.random_split() 后,數據中缺少批次大小

[英]Pytorch: Batch size is missing in data after torch.utils.random_split() is used on dataloader.dataset

我使用 random_split() 將我的數據分成訓練和測試,我觀察到如果在創建數據加載器后進行隨機拆分,從數據加載器獲取一批數據時會丟失批量大小。

import torch
from torchvision import transforms, datasets
from torch.utils.data import random_split

# Normalize the data
transform_image = transforms.Compose([
  transforms.Resize((240, 320)),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

data = '/data/imgs/train'

def load_dataset():
  data_path = data
  main_dataset = datasets.ImageFolder(
    root = data_path,
    transform = transform_image
  )

  loader = torch.utils.data.DataLoader(
    dataset = main_dataset,
    batch_size= 64,
    num_workers = 0,
    shuffle= True
  )

  # Dataset has 22424 data points
  trainloader, testloader = random_split(loader.dataset, [21000, 1424])

  return trainloader, testloader

trainloader, testloader = load_dataset()

現在從訓練和測試加載器中獲取一批圖像:

images, labels = next(iter(trainloader))
images.shape
# %%
len(trainloader)

# %%
images_test, labels_test = next(iter(testloader))
images_test.shape

# %%
len(testloader)

我得到的輸出沒有訓練或測試批次的批次大小。 輸出暗淡應該是 [batch x channel x H x W] 但我得到 [channel x H x W]。

輸出:

在此處輸入圖片說明

但是,如果我從數據集創建拆分,然后使用拆分創建兩個數據加載器,我會在輸出中獲得批量大小。

def load_dataset():
    data_path = data
    main_dataset = datasets.ImageFolder(
      root = data_path,
      transform = transform_image
    )
    # Dataset has 22424 data points
    train_data, test_data = random_split(main_dataset, [21000, 1424])

    trainloader = torch.utils.data.DataLoader(
      dataset = train_data,
      batch_size= 64,
      num_workers = 0,
      shuffle= True
    )

    testloader = torch.utils.data.DataLoader(
      dataset = test_data,
      batch_size= 64,
      num_workers= 0,
      shuffle= True
    )

    return trainloader, testloader

trainloader, testloader = load_dataset()

在運行相同的 4 個命令以獲取單個訓練和測試批次時:

在此處輸入圖片說明

第一種方法錯了嗎? 雖然長度顯示數據已經被拆分了。 那么為什么我看不到批量大小呢?

第一種方法是錯誤的。

只有DataLoader實例返回批次的項目。 像實例一樣的Dataset沒有。

當您調用make_split您將傳遞給它loader.dataset ,它只是對main_dataset (不是DataLoader )的引用。 結果是trainloadertestloaderDataset s 而不是DataLoader s。 實際上,當您從load_dataset返回時,您會丟棄loader ,它是您唯一的DataLoader

第二個版本是你應該怎么做才能獲得兩個單獨的DataLoader

您正在將數據集拆分為兩個。 這將為您提供 2 個數據集,在迭代時將返回形狀channel, height, width (即3,h,w單個圖像張量,並且默認情況下不會為您提供圍繞這些數據集的 Dataloader。
您接下來所做的實際上是下一步正確的操作,即圍繞每個數據集創建一個 Dataloader。 您在 Dataloader 中定義批次大小,現在迭代 Dataloader 將返回形狀為batch_size, channel, height, width張量。

即使您打算提供尺寸為 1 的模型批次,您也必須在張量中具有批次尺寸維度。 為此,您可以使用batchsize=1的 Dataloader 或在開始時為圖像XX.unsqueeze(0)使用torch.unsqueeze(X, 0)添加一個虛擬維度,使張量的形狀為1,3,h,w

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM