Iterating over subsets from torch.utils.data.random_split

I am currently loading a folder of AI training data. Each subfolder is named after a label and contains the corresponding images. Loading it with PyTorch's ImageFolder works well:

import torch
import torchvision

def load_dataset():
    data_path = 'C:/example_folder/'

    train_dataset_manual = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )

    train_loader_manual = torch.utils.data.DataLoader(
        train_dataset_manual,
        batch_size=1,
        num_workers=0,
        shuffle=True
    )

    return train_loader_manual

full_dataset = load_dataset()

Now I want to have this dataset split into a training and a test data set. I am using the random_split function for this:

training_data_size = 0.8

train_size = int(training_data_size * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

The full_dataset is an object of type torch.utils.data.dataloader.DataLoader. I can iterate through it with a loop like this:

for batch_idx, (data, target) in enumerate(full_dataset):
    print(batch_idx)

The train_dataset is an object of type torch.utils.data.dataset.Subset. If I try to loop through it in the same way, I get TypeError: 'DataLoader' object is not subscriptable:

for batch_idx, (data, target) in enumerate(train_dataset):
    print(batch_idx)

How can I loop through it? I am relatively new to Python.

Thanks!

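You need to apply random_split to a Dataset, not a DataLoader. random_split returns Subset objects that fetch items by integer index from whatever they wrap, and a DataLoader does not support indexing, which is why the error only appears once you start iterating. The dataset used to define the DataLoader is available through its dataset attribute (here, full_dataset.dataset).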
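To illustrate why the split itself succeeds and the failure shows up later, here is a minimal plain-Python sketch. TinySubset and TinyLoader are hypothetical stand-ins written for this illustration, not PyTorch classes. The sketch assumes, as PyTorch's Subset does, that a subset stores the wrapped object plus a list of indices and looks items up with wrapped[index].

```python
class TinySubset:
    """Simplified stand-in for torch.utils.data.Subset (illustration only)."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # The lookup is delegated to the wrapped object,
        # so that object must support integer indexing.
        return self.dataset[self.indices[idx]]


class TinyLoader:
    """Simplified stand-in for a DataLoader: sized and iterable, but not indexable."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        return iter(self.dataset)


samples = ["a", "b", "c", "d", "e"]
loader = TinyLoader(samples)

# Wrapping the loader succeeds, because nothing is indexed yet.
bad_subset = TinySubset(loader, [3, 0, 4])
try:
    bad_subset[0]
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Wrapping the underlying dataset works, because a list supports indexing.
good_subset = TinySubset(loader.dataset, [3, 0, 4])
items = [good_subset[i] for i in range(len(good_subset))]
```

The first lookup on bad_subset raises a TypeError saying the loader object is not subscriptable, while good_subset returns the items at positions 3, 0 and 4 of the underlying list.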

For example, you could do:

train_dataset, test_dataset = torch.utils.data.random_split(full_dataset.dataset, [train_size, test_size])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, num_workers=0, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=False)

Then you can iterate over train_loader and test_loader as expected.
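For a version that runs without an image folder, the following sketch applies the same pattern end to end. It is an illustration under assumptions, not the asker's exact setup: a TensorDataset of ten dummy tensors stands in for the ImageFolder dataset, and the seeded generator is an optional addition that makes the split reproducible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in data: 10 dummy "images" (3x4x4 tensors) with integer labels.
# In the question, this would be the ImageFolder dataset instead.
images = torch.zeros(10, 3, 4, 4)
labels = torch.arange(10) % 2
full_dataset = TensorDataset(images, labels)

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

# Split the Dataset itself (not a DataLoader).
# The seeded generator is optional and only makes the split repeatable.
generator = torch.Generator().manual_seed(0)
train_dataset, test_dataset = random_split(
    full_dataset, [train_size, test_size], generator=generator
)

# Wrap each subset in its own DataLoader.
train_loader = DataLoader(train_dataset, batch_size=1, num_workers=0, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=False)

# The loop from the question now works on the training loader.
train_batches = 0
for batch_idx, (data, target) in enumerate(train_loader):
    train_batches += 1

test_batches = sum(1 for _ in test_loader)
```

With ten samples and an 80/20 split, the training loader yields eight batches and the test loader two. The two subsets share the same underlying dataset and differ only in the indices they hold.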
