繁体   English   中英

PyTorch DataLoader 将批次作为列表返回,批次作为唯一条目。 如何从我的 DataLoader 获取张量的最佳方式

[英]PyTorch DataLoader returns the batch as a list with the batch as the only entry. How is the best way to get a tensor from my DataLoader

我目前有以下情况,我想使用DataLoader批处理 numpy 数组:

import numpy as np
import torch
import torch.utils.data as data_utils

# Create toy data
x = np.linspace(start=1, stop=10, num=10)
x = np.array([np.random.normal(size=len(x)) for i in range(100)])
print(x.shape)
# >> (100,10)

# Create DataLoader
input_as_tensor = torch.from_numpy(x).float()
dataset = data_utils.TensorDataset(input_as_tensor)
dataloader = data_utils.DataLoader(dataset,
                                   batch_size=100,
                                  )
batch = next(iter(dataloader))

print(type(batch))
# >> <class 'list'>

print(len(batch))
# >> 1

print(type(batch[0]))
# >> class 'torch.Tensor'>

我希望batch已经是torch.Tensor 到目前为止,我像这样索引批处理, batch[0]以获得张量,但我觉得这不是很漂亮,并且使代码更难阅读。

我发现DataLoader需要一个名为 collate_fn 的批处理collate_fn 但是,设置data_utils.DataLoader(..., collage_fn=lambda batch: batch[0])只会将列表更改为元组(tensor([ 0.8454, ..., -0.5863]),) ,其中唯一的条目是批处理作为张量。

你会帮助我找出如何优雅地将批处理转换为张量(即使这包括告诉我批量索引单个条目是可以的),这对我有很大帮助。

很抱歉给我的回答带来不便。

实际上,您不必从张量创建Dataset ,您可以直接传递torch.Tensor ,因为它实现了__getitem____len__ ,所以这就足够了:

import numpy as np
import torch
import torch.utils.data as data_utils

# Create toy data
x = np.linspace(start=1, stop=10, num=10)
x = np.array([np.random.normal(size=len(x)) for i in range(100)])

# Create DataLoader
dataset = torch.from_numpy(x).float()
dataloader = data_utils.DataLoader(dataset, batch_size=100)
batch = next(iter(dataloader))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM