Can I pre-transform images when training a deep learning model?

If you have ever trained a model, you may have seen the DataLoader (e.g. the PyTorch DataLoader I am using) become the bottleneck, because every time you fetch a training sample from the dataset, the data transformation is performed on the fly. Here I take the __getitem__ method of DatasetFolder from torchvision.datasets as an example.

    path, target = self.samples[index]
    sample = self.loader(path)
    if self.transform is not None:
        sample = self.transform(sample)
    if self.target_transform is not None:
        target = self.target_transform(target)

    return sample, target
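
A quick way to check whether loading and transformation really dominate is to time one pass over the DataLoader without any model step (a rough sketch; dataset stands for whatever dataset you already have):

    import time
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=64, num_workers=4)

    start = time.perf_counter()
    for batch in loader:
        pass  # no forward/backward pass, so this measures pure loading + transform time
    print(f"one epoch of loading took {time.perf_counter() - start:.1f}s")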

I wonder whether we can pre-process the images (e.g. ImageNet) into tensors in advance and save them to disk, and then modify the __getitem__ function to read those tensors directly from disk. How efficient is this approach? Has anyone tried this solution before?

My worry is that loading from disk would add its own cost and likely become the new bottleneck (instead of the data transform we had before). Another concern is size: for example, one ImageNet image takes 74 MB when saved as a tensor using the standard transformation:

    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
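
You can measure that on-disk footprint directly before committing to the approach (a sketch; some_image.jpg is a placeholder path):

    import os
    import torch
    from PIL import Image
    from torchvision import transforms

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    img = Image.open("some_image.jpg").convert("RGB")  # placeholder path
    t = preprocess(img)         # float32 tensor of shape [3, 224, 224]
    torch.save(t, "sample.pt")
    # the raw tensor data is 3 * 224 * 224 * 4 bytes, plus a little torch.save overhead
    print(os.path.getsize("sample.pt"), "bytes on disk")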

Two approaches:

1) You can cache the dataset in memory if you have enough RAM and reuse it from the second epoch on. The limitation is that you are restricted to num_workers=0 in the first epoch, because with multiple workers each worker process gets its own copy of the dataset, so a cache filled inside the workers is lost when they exit. Something like this:

    import torch
    from torch.utils.data import Dataset

    class ImageNetDataset(Dataset):
        def __init__(self, data, use_cache=False):
            self.data = data              # the underlying (un-cached) samples
            self.cached_data = []
            self.use_cache = use_cache

        def __getitem__(self, index):
            if not self.use_cache:
                x = self.data[index]
                # appending only lines up with `index` if the first epoch
                # reads the samples in order (shuffle=False)
                self.cached_data.append(x)
            else:
                x = self.cached_data[index]
            return x

        def __len__(self):
            return len(self.data)

        def set_use_cache(self, use_cache):
            if use_cache:
                self.cached_data = torch.stack(self.cached_data)
            else:
                self.cached_data = []
            self.use_cache = use_cache
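
A possible usage pattern (a sketch; data stands for whatever pre-loaded samples you construct the dataset with):

    from torch.utils.data import DataLoader

    dataset = ImageNetDataset(data, use_cache=False)

    # first epoch: num_workers=0 and shuffle=False so the cache is filled
    # in the main process, in index order
    loader = DataLoader(dataset, batch_size=64, num_workers=0, shuffle=False)
    for batch in loader:
        ...  # train as usual; every access also populates the cache

    dataset.set_use_cache(True)

    # later epochs can shuffle and use multiple workers; samples now come
    # from the in-memory cache
    loader = DataLoader(dataset, batch_size=64, num_workers=4, shuffle=True)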

2) Transform the dataset beforehand and save it to disk. You need to write a small separate script for that, and then point the training code at the resulting folder; see the sketch below.
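
A minimal sketch of that separate step, assuming hypothetical paths imagenet/train for the source images and imagenet_tensors/ for the output:

    import os
    import torch
    from torch.utils.data import Dataset
    from torchvision import datasets, transforms

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # one-off pass: apply the transform to every image and save the result
    src = datasets.ImageFolder("imagenet/train", transform=preprocess)
    os.makedirs("imagenet_tensors", exist_ok=True)
    for i in range(len(src)):
        sample, target = src[i]
        torch.save((sample, target), f"imagenet_tensors/{i}.pt")

    # during training, __getitem__ only has to read a tensor back from disk
    class TensorFolderDataset(Dataset):
        def __init__(self, root):
            self.root = root
            self.length = len(os.listdir(root))

        def __getitem__(self, index):
            return torch.load(os.path.join(self.root, f"{index}.pt"))

        def __len__(self):
            return self.length

Whether this is actually faster depends on your storage: reading many small .pt files can itself be I/O-bound, which is exactly the concern raised in the question.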
