How do I get the total number of batch iterations from a PyTorch DataLoader?
The following is common training code:
for i, batch in enumerate(dataloader):
Is there any method to get the total number of iterations for this for loop?
In my NLP problem, the total number of iterations differs from int(n_train_samples/batch_size)...
For example, if I truncate the training data to only 10,000 samples and set the batch size to 1024, then 363 iterations occur in my NLP problem.
I wonder how to get the total number of iterations in the for loop.
Thank you.
len(dataloader)
returns the total number of batches. It depends on the __len__ method of your dataset, so make sure it is implemented correctly.
There is one additional parameter to consider when creating the DataLoader, called drop_last.
If drop_last=True, the length is number_of_training_examples // batch_size.
If drop_last=False, it may be number_of_training_examples // batch_size + 1, since the final partial batch is kept when the dataset size is not divisible by the batch size.
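The two cases can be sketched with plain integer arithmetic (a minimal illustration, independent of PyTorch; the helper name num_batches is my own):

```python
import math

def num_batches(n_examples, batch_size, drop_last=False):
    """Number of batches a DataLoader will yield over a map-style dataset."""
    if drop_last:
        # incomplete final batch is dropped
        return n_examples // batch_size
    # final partial batch is kept
    return math.ceil(n_examples / batch_size)

print(num_batches(10_000, 1024, drop_last=True))   # 9
print(num_batches(10_000, 1024, drop_last=False))  # 10
```

With 10,000 samples and a batch size of 1024, this yields 9 or 10 batches, which is what len(dataloader) reports for the corresponding drop_last setting.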
import torchvision
from torch.utils.data import DataLoader

BS = 128
t_train = torchvision.transforms.ToTensor()  # placeholder transform; substitute your own pipeline
ds_train = torchvision.datasets.CIFAR10('/data/cifar10', download=True, train=True, transform=t_train)
dl_train = DataLoader(ds_train, batch_size=BS, drop_last=True, shuffle=True)
For predefined datasets you can get the number of examples like this:
# number of examples
len(dl_train.dataset)
The correct number of batches in the dataloader is always:
# number of batches
len(dl_train)