
PyTorch/torchvision - modify images and labels of a Dataset object

So I have this line of code to load a dataset of images from two classes called "0" and "1" for simplicity:

train_data = torchvision.datasets.ImageFolder(os.path.join(TRAIN_DATA_DIR), train_transform)

and then I prepare the loader to be used with my model in this way:

train_loader = torch.utils.data.DataLoader(train_data, TRAIN_BATCH_SIZE, shuffle=True)

So for now each image is associated with a class. Between those two lines of code, I want to take each image and apply a transformation to it, say a rotation by one of four angles (0, 90, 180 or 270 degrees), and record that as an additional label with four classes: 0, 1, 2, 3. In the end I want the dataset to contain the rotated images, each labeled with a list of two values: the class of the image and the applied rotation.

I tried this. It raises no error, but the dataset appears unchanged when I print the labels afterwards:

for idx, label in enumerate(train_data.targets):
    train_data.targets[idx] = [label, 1]

Is there a nice way to do it by modifying directly train_data without requiring a custom dataset?

No, there isn't. If you want to use datasets.ImageFolder, you have to accept its limited flexibility. In fact, ImageFolder is simply a subclass of DatasetFolder, which is pretty much what a custom dataset is. You can see the following section of __getitem__ in its source code:

if self.transform is not None:
    sample = self.transform(sample)
if self.target_transform is not None:
    target = self.target_transform(target)

This makes what you want impossible, since your transform needs to modify the image and the target together, whereas here they are transformed independently.

So, start by writing your own subclass of Dataset, modeled on DatasetFolder, and implement your own transform that takes an image and a target together and returns both transformed values. Here is an example of such a transform class. If you have several of them, you will have to chain them yourself, because torchvision's transforms.Compose passes only a single argument to each transform:

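import random
import torchvision.transforms.functional as TF

class RotateTransform(object):
    def __call__(self, image, target):
        # Pick one of four rotations: 0, 90, 180 or 270 degrees
        k = random.randint(0, 3)
        # Works on PIL images and tensors; keeps the original size, so non-square images lose their corners
        image = TF.rotate(image, 90 * k)
        # New target: [original class, rotation class 0-3]
        return image, [target, k]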
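Putting the pieces together, here is a minimal sketch that assumes the images are already CHW tensors (for example, train_transform ends with transforms.ToTensor()). JointCompose, RandomRot90Label and JointTransformDataset are hypothetical names, not torchvision classes. The rotation uses torch.rot90, and the wrapper dataset reads (image, target) pairs from any base dataset, such as train_data from the question; a small list of random tensors stands in for it here.

```python
import random

import torch
from torch.utils.data import DataLoader, Dataset


class JointCompose:
    """Chain transforms that each take and return an (image, target) pair."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class RandomRot90Label:
    """Rotate a CHW tensor by k * 90 degrees and return [target, k] as the label."""

    def __call__(self, image, target):
        k = random.randint(0, 3)  # 0 -> 0, 1 -> 90, 2 -> 180, 3 -> 270 degrees
        image = torch.rot90(image, k, dims=(1, 2))  # rotate in the H/W plane
        return image, [target, k]


class JointTransformDataset(Dataset):
    """Wrap a dataset yielding (image, target) and apply a joint transform."""

    def __init__(self, base, joint_transform=None):
        self.base = base
        self.joint_transform = joint_transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        image, target = self.base[index]
        if self.joint_transform is not None:
            image, target = self.joint_transform(image, target)
        return image, target


# Stand-in for train_data: four random 3x8x8 "images" with classes 0/1
base = [(torch.rand(3, 8, 8), i % 2) for i in range(4)]
rot_data = JointTransformDataset(base, JointCompose([RandomRot90Label()]))

loader = DataLoader(rot_data, batch_size=4, shuffle=False)
# The default collate function turns the [class, k] lists into two tensors
images, (classes, rotations) = next(iter(loader))
print(images.shape)  # torch.Size([4, 3, 8, 8])
print(classes)       # tensor([0, 1, 0, 1])
```

Rotating by 90 or 270 degrees swaps height and width, so with non-square images the base transform should resize or crop to a square before this wrapper runs; otherwise the images in a batch cannot be stacked.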

If that's too much trouble for your case, your best option is what @jchaykow mentioned: simply modify your files before running your code.
