
How to get the total number of batch iterations from a PyTorch DataLoader?

I have a question: how do I get the total number of batch iterations from a PyTorch DataLoader?

The following is a common training loop:

for i, batch in enumerate(dataloader):

Is there a way to get the total number of iterations of this for-loop?

In my NLP problem, the total number of iterations differs from int(n_train_samples / batch_size)...

For example, if I truncate the training data to only 10,000 samples and set the batch size to 1024, then 363 iterations occur in my NLP problem.

How can I get the total number of iterations of the for-loop?

Thank you.

len(dataloader) returns the total number of batches. It depends on the __len__ function of your dataset, so make sure it is set correctly.
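To make that concrete, here is a minimal sketch of a custom Dataset; the ToyDataset class and its sizes are made up for illustration. The DataLoader derives its own length from the dataset's __len__:

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __init__(self, n):
        self.data = torch.arange(n)

    def __len__(self):
        return len(self.data)  # DataLoader computes its length from this

    def __getitem__(self, idx):
        return self.data[idx]

loader = DataLoader(ToyDataset(10), batch_size=4)
print(len(loader))  # 3 == ceil(10 / 4), since drop_last defaults to False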

There is one additional parameter to consider when creating the DataLoader: drop_last.

If drop_last=True, the length is number_of_training_examples // batch_size. If drop_last=False, the last partial batch is kept, so the length is number_of_training_examples // batch_size + 1 whenever the dataset size is not evenly divisible by the batch size (i.e. ceil(number_of_training_examples / batch_size)).

import torchvision
from torch.utils.data import DataLoader

BS = 128
t_train = torchvision.transforms.ToTensor()  # t_train was not defined in the original post; any transform works here
ds_train = torchvision.datasets.CIFAR10('/data/cifar10', download=True, train=True, transform=t_train)
dl_train = DataLoader(ds_train, batch_size=BS, drop_last=True, shuffle=True)
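To see the drop_last arithmetic directly, here is a quick check with a throwaway TensorDataset, using the 10,000-sample / batch-size-1024 figures from the question (the data itself is synthetic):

import torch
from torch.utils.data import TensorDataset, DataLoader

ds = TensorDataset(torch.zeros(10_000, 1))
print(len(DataLoader(ds, batch_size=1024, drop_last=True)))   # 9  == 10000 // 1024
print(len(DataLoader(ds, batch_size=1024, drop_last=False)))  # 10 == 9 + 1 (last partial batch kept)

Note that neither setting yields the 363 iterations reported in the question, which suggests that the NLP dataset's __len__ returns something other than the raw sample count.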

For predefined datasets, you can get the number of examples like this:

# number of examples
len(dl_train.dataset) 

And the correct number of batches in the DataLoader is always:

# number of batches
len(dl_train) 
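Putting it together, here is a sketch of using len(dataloader) for progress reporting in the question's loop (dl_train is the loader from the snippet above):

n_batches = len(dl_train)              # total number of for-loop iterations
for i, batch in enumerate(dl_train):
    # ... training step ...
    print(f'batch {i + 1}/{n_batches}')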
