简体   繁体   English

Pytorch 默认数据加载器因大型图像分类训练集卡住

[英]Pytorch default dataloader gets stuck for large image classification training set

I am training image classification models in Pytorch and using their default data loader to load my training data.我正在 Pytorch 中训练图像分类模型,并使用它们的默认数据加载器来加载我的训练数据。 I have a very large training dataset, so usually a couple thousand sample images per class.我有一个非常大的训练数据集,所以通常每个类有几千个样本图像。 I've trained models with about 200k images total without issues in the past.过去,我已经训练了总共大约 20 万张图像的模型,没有出现任何问题。 However I've found that when have over a million images in total, the Pytorch data loader get stuck.但是我发现当总共有超过一百万张图像时,Pytorch 数据加载器会卡住。

I believe the code is hanging when I call datasets.ImageFolder(...) .我相信当我调用datasets.ImageFolder(...)时代码挂了。 When I Ctrl-C, this is consistently the output:当我 Ctrl-C 时,这始终是输出:

Traceback (most recent call last):                                                                                                 │
  File "main.py", line 412, in <module>                                                                                            │
    main()                                                                                                                         │
  File "main.py", line 122, in main                                                                                                │
    run_training(args.group, args.num_classes)                                                                                     │
  File "main.py", line 203, in run_training                                                                                        │
    train_loader = create_dataloader(traindir, tfm.train_trans, shuffle=True)                                                      │
  File "main.py", line 236, in create_dataloader                                                                                   │
    dataset = datasets.ImageFolder(directory, trans)                                                                               │
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 209, in __init__     │
    is_valid_file=is_valid_file)                                                                                                   │
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 94, in __init__      │
    samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)                                                     │
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 47, in make_dataset  │
    for root, _, fnames in sorted(os.walk(d)):                                                                                     │
  File "/usr/lib/python3.5/os.py", line 380, in walk                                                                               │
    is_dir = entry.is_dir()                                                                                                        │
Keyboard Interrupt                                                                                                                       

I thought there might be a deadlock somewhere, however based off the stack output from Ctrl-C it doesn't look like its waiting on a lock.我认为某处可能存在死锁,但是根据 Ctrl-C 的堆栈输出,它看起来不像是在等待锁定。 So then I thought that the dataloader was just slow because I was trying to load a lot more data.然后我认为数据加载器很慢,因为我试图加载更多数据。 I let it run for about 2 days and it didn't make any progress, and in the last 2 hours of loading I checked the amount of RAM usage stayed the same.我让它运行了大约 2 天,但没有取得任何进展,在加载的最后 2 小时内,我检查了 RAM 使用量保持不变。 I also have been able to load training datasets with over 200k images in less than a couple hours in the past.过去,我还能够在不到几个小时的时间内加载超过 20 万张图像的训练数据集。 I also tried upgrading my GCP machine to have 32 cores, 4 GPUs, and over 100GB in RAM, however it seems to be that after a certain amount of memory is loaded the data loader just gets stuck.我还尝试将我的 GCP 机器升级为 32 个内核、4 个 GPU 和超过 100GB 的 RAM,但是似乎在加载了一定数量的内存后,数据加载器就会卡住。

I'm confused how the data loader could be getting stuck while looping through the directory, and I'm still unsure if its stuck or just extremely slow.我很困惑数据加载器在遍历目录时如何卡住,我仍然不确定它是卡住还是非常慢。 Is there some way I can change the Pytortch dataloader to be able to handle 1million+ images for training?有什么方法可以更改 Pytortch 数据加载器,使其能够处理 100 万张以上的训练图像? Any debugging suggestions are also appreciated!任何调试建议也值得赞赏!

Thank you!谢谢!

It's not a problem with DataLoader , it's a problem with torchvision.datasets.ImageFolder and how it works (and why it works much much worse the more data you have).这不是DataLoader的问题, torchvision.datasets.ImageFolder及其工作方式的问题(以及为什么它在您拥有的数据越torchvision.datasets.ImageFolder效果越差)。

It hangs on this line, as indicated by your error:它挂在这条线上,如您的错误所示:

for root, _, fnames in sorted(os.walk(d)): 

Source can be found here .来源可以在这里找到。

Underlying problem is it keeps each path and corresponding label in giant list , see the code below (a few things removed for brevity):潜在的问题是它将每个path和相应的label保留在巨型list ,请参阅下面的代码(为简洁起见,删除了一些内容):

def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
    images = []
    dir = os.path.expanduser(dir)
    # Iterate over all subfolders which were found previously
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target) # Create path to this subfolder
        # Assuming it is directory (which usually is the case)
        for root, _, fnames in sorted(os.walk(d, followlinks=True)):
            # Iterate over ALL files in this subdirectory
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                # Assuming it is correctly recognized as image file
                item = (path, class_to_idx[target])
                # Add to path with all images
                images.append(item)

    return images

Obviously images will contain 1 million strings (quite lengthy as well) and corresponding int for the classes which definitely is a lot and depends on RAM and CPU.显然,图像将包含 100 万个字符串(也很长)和相应的类的int ,这绝对是很多并且取决于 RAM 和 CPU。

You can create your own datasets though (provided you change names of your images beforehand) so no memory will be occupied by the dataset .您可以创建自己的数据集(前提是您事先更改了图像的名称),因此dataset不会占用内存

Setup data structure设置数据结构

Your folder structure should look like this:您的文件夹结构应如下所示:

root
    class1
    class2
    class3
    ...

Use how many classes you have/need.使用您拥有/需要的课程数量。

Now each class should have the following data:现在每个class都应该有以下数据:

class1
    0.png
    1.png
    2.png
    ...

Given that you can move on to creating datasets.鉴于您可以继续创建数据集。

Create Datasets创建数据集

Below torch.utils.data.Dataset uses PIL to open images, you could do it in another way though:torch.utils.data.Dataset下面使用PIL打开图像,你可以用另一种方式来做:

import os
import pathlib

import torch
from PIL import Image


class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, root: str, folder: str, klass: int, extension: str = "png"):
        self._data = pathlib.Path(root) / folder
        self.klass = klass
        self.extension = extension
        # Only calculate once how many files are in this folder
        # Could be passed as argument if you precalculate it somehow
        # e.g. ls | wc -l on Linux
        self._length = sum(1 for entry in os.listdir(self._data))

    def __len__(self):
        # No need to recalculate this value every time
        return self._length

    def __getitem__(self, index):
        # images always follow [0, n-1], so you access them directly
        return Image.open(self._data / "{}.{}".format(str(index), self.extension))

Now you can create your datasets easily (folder structure assumed like the one above:现在您可以轻松创建您的数据集(假设文件夹结构如上:

root = "/path/to/root/with/images"
dataset = (
    ImageDataset(root, "class0", 0)
    + ImageDataset(root, "class1", 1)
    + ImageDataset(root, "class2", 2)
)

You could add as many datasets with specified classes as you wish, do it in loop or whatever.您可以根据需要添加任意数量的具有指定类的datasets ,在循环中进行或以其他方式进行。

Finally, use torch.utils.data.DataLoader as per usual, eg:最后,像往常一样使用torch.utils.data.DataLoader ,例如:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM