繁体   English   中英

如何从pytorch dataloader获取批量迭代的总数?

[英]How to get the total number of batch iteration from pytorch dataloader?

我有一个问题,如何从 pytorch 数据加载器获取批量迭代的总数?

以下是训练的常用代码

for i, batch in enumerate(dataloader):

那么,有没有什么方法可以获取“for循环”的总迭代次数?

在我的 NLP 问题中,总迭代次数与 int(n_train_samples/batch_size)...

例如,如果我只截断训练数据 10,000 个样本并将批大小设置为 1024,那么在我的 NLP 问题中会发生 363 次迭代。

我想知道如何获得“for 循环”中的总迭代次数。

谢谢你。

len(dataloader)返回批次总数。 这取决于数据集的__len__函数,因此请确保其设置正确。

创建数据加载器时还有一个附加参数。 它被称为drop_last

如果drop_last=True则长度为number_of_training_examples // batch_size 如果drop_last=False它可能是number_of_training_examples // batch_size +1

BS=128
ds_train = torchvision.datasets.CIFAR10('/data/cifar10', download=True, train=True, transform=t_train)
dl_train = DataLoader( ds_train, batch_size=BS, drop_last=True, shuffle=True)

对于预定义的数据集,您可能会获得如下示例的数量:

# number of examples
len(dl_train.dataset) 

数据加载器中正确的批次数始终为:

# number of batches
len(dl_train) 

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM