
PyTorch DataLoader returns the batch as a list with the batch as the only entry. What is the best way to get a tensor from my DataLoader?

I currently have the following situation where I want to use DataLoader to batch a numpy array:

import numpy as np
import torch
import torch.utils.data as data_utils

# Create toy data
x = np.linspace(start=1, stop=10, num=10)
x = np.array([np.random.normal(size=len(x)) for i in range(100)])
print(x.shape)
# >> (100, 10)

# Create DataLoader
input_as_tensor = torch.from_numpy(x).float()
dataset = data_utils.TensorDataset(input_as_tensor)
dataloader = data_utils.DataLoader(dataset, batch_size=100)
batch = next(iter(dataloader))

print(type(batch))
# >> <class 'list'>

print(len(batch))
# >> 1

print(type(batch[0]))
# >> <class 'torch.Tensor'>

I expected the batch to already be a torch.Tensor. For now I index it as batch[0] to get the Tensor, but I feel this is not very pretty and makes the code harder to read.

I found that DataLoader accepts a batch-processing function called collate_fn. However, setting data_utils.DataLoader(..., collate_fn=lambda batch: batch[0]) only changes the list to a tuple, (tensor([ 0.8454, ..., -0.5863]),), where the only entry is the batch as a Tensor.
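For context, collate_fn receives a list of individual samples rather than an assembled batch, and with TensorDataset each sample is a one-element tuple. Indexing that list with [0] therefore returns the first sample's tuple, not a stacked batch. Below is a minimal sketch of a collate function that returns a plain tensor. The name stack_single is made up for illustration, and the random toy data only mirrors the shapes used in the question.

```python
import torch
import torch.utils.data as data_utils

# Toy data with the same shape as in the question: 100 samples, 10 features
x = torch.randn(100, 10)
dataset = data_utils.TensorDataset(x)


def stack_single(samples):
    # `samples` is a list of 1-tuples, one per dataset item: [(row0,), (row1,), ...]
    # Take the tensor out of each tuple and stack along a new batch dimension.
    return torch.stack([sample[0] for sample in samples])


dataloader = data_utils.DataLoader(dataset, batch_size=100, collate_fn=stack_single)
batch = next(iter(dataloader))

print(type(batch))   # <class 'torch.Tensor'>
print(batch.shape)   # torch.Size([100, 10])
```

This keeps TensorDataset in place, which may matter if more tensors (for example labels) are added later, at the cost of a custom collate function.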

It would help me a lot if you could show me how to elegantly turn the batch into a tensor (even if the answer is that indexing the single entry of the batch is fine).
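If indexing with batch[0] reads poorly, one option is to unpack the one-element batch directly in the loop header. This is ordinary Python sequence unpacking and needs no change to the DataLoader. A small sketch, where the batch size of 25 is chosen only so the loop runs several times:

```python
import torch
import torch.utils.data as data_utils

x = torch.randn(100, 10)
dataloader = data_utils.DataLoader(data_utils.TensorDataset(x), batch_size=25)

shapes = []
for (xb,) in dataloader:  # unpack the single entry; xb is a torch.Tensor
    shapes.append(tuple(xb.shape))

print(shapes)  # [(25, 10), (25, 10), (25, 10), (25, 10)]
```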

Sorry for the inconvenience with my answer.

Actually, you don't have to create a Dataset from your tensor. You can pass a torch.Tensor directly, since it implements __getitem__ and __len__, so this is sufficient:

import numpy as np
import torch
import torch.utils.data as data_utils

# Create toy data
x = np.linspace(start=1, stop=10, num=10)
x = np.array([np.random.normal(size=len(x)) for i in range(100)])

# Create DataLoader
dataset = torch.from_numpy(x).float()
dataloader = data_utils.DataLoader(dataset, batch_size=100)
batch = next(iter(dataloader))
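As a quick check of the snippet above, the following sketch inspects the type and shape of the resulting batch. It uses torch.randn as a stand-in for the converted NumPy array and relies on the default collate function stacking the individual rows into one tensor.

```python
import torch
import torch.utils.data as data_utils

dataset = torch.randn(100, 10)  # stand-in for torch.from_numpy(x).float()
dataloader = data_utils.DataLoader(dataset, batch_size=100)
batch = next(iter(dataloader))

print(type(batch))   # <class 'torch.Tensor'>
print(batch.shape)   # torch.Size([100, 10])
```

Note that this shortcut covers the single-tensor case only. With several aligned tensors, such as inputs and labels, TensorDataset is still the convenient choice, and each batch then comes back as a list with one entry per tensor.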


 