
How do I implement a PyTorch Dataset for use with AWS SageMaker?

I have implemented a PyTorch Dataset that works locally (on my own desktop), but when executed on AWS SageMaker, it breaks. My Dataset implementation is as follows.

from os import listdir
from os.path import isfile, join

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class ImageDataset(Dataset):
    def __init__(self, path='./images', transform=None):
        self.path = path
        # collect every .jpg file directly under `path`
        self.files = [join(path, f) for f in listdir(path) if isfile(join(path, f)) and f.endswith('.jpg')]
        self.transform = transform
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_name = self.files[idx]

        # we may infer the label from the filename, e.g. 'cat-3.jpg' -> 3
        dash_idx = img_name.rfind('-')
        dot_idx = img_name.rfind('.')
        label = int(img_name[dash_idx + 1:dot_idx])

        # convert to RGB so the 3-channel Normalize also works for grayscale/RGBA files
        image = Image.open(img_name).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label
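The filename-to-label logic above can also be expressed with `os.path` helpers, which only looks at the file name and never at the directory part of the path. A minimal sketch, assuming labels are always encoded as a `-<int>` suffix before the extension; `label_from_filename` is a hypothetical helper name:

```python
import os


def label_from_filename(path: str) -> int:
    """Return the integer label after the last '-' in the file name.

    Hypothetical helper mirroring the rfind('-') / rfind('.') logic above,
    e.g. 'images/cat-3.jpg' -> 3.
    """
    # strip the directory and the extension: 'images/cat-3.jpg' -> 'cat-3'
    stem, _ext = os.path.splitext(os.path.basename(path))
    # split on the last '-' only, so names like 'my-dog-12' still work
    _prefix, sep, label = stem.rpartition('-')
    if not sep:
        raise ValueError(f"no '-<label>' suffix in file name: {path!r}")
    return int(label)
```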

I am following this example and this one too, and I run the estimator as follows.

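from sagemaker.pytorch import PyTorch

# `role` and `instance_type` are defined elsewhere in the notebook
inputs = {
    'train': 'file://images',
    'eval': 'file://images'
}
# train_instance_count / train_instance_type are SageMaker Python SDK v1 names
# (v2 renamed them to instance_count / instance_type)
estimator = PyTorch(entry_point='pytorch-train.py',
                    role=role,
                    framework_version='1.0.0',
                    train_instance_count=1,
                    train_instance_type=instance_type)
estimator.fit(inputs)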

I get the following error.

FileNotFoundError: [Errno 2] No such file or directory: './images'

In the example that I am following, they upload the CIFAR-10 dataset (which is downloaded locally) to S3.

inputs = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix='data/cifar10')

If I take a peek at inputs, it is just a string literal: s3://sagemaker-us-east-3-184838577132/data/cifar10. The code that creates a Dataset and a DataLoader is shown here, but it does not help unless I track down the source and step through the logic.
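Since `upload_data` returns only an S3 URI string, a boto3-based dataset needs the bucket name and key prefix as separate values. A minimal sketch using the standard library's `urllib.parse`; `split_s3_uri` is a hypothetical helper name:

```python
from typing import Tuple
from urllib.parse import urlparse


def split_s3_uri(uri: str) -> Tuple[str, str]:
    """Split 's3://bucket/some/prefix' into ('bucket', 'some/prefix')."""
    parsed = urlparse(uri)
    if parsed.scheme != 's3':
        raise ValueError(f'not an S3 URI: {uri!r}')
    # netloc holds the bucket; path holds '/key/prefix' with a leading slash
    return parsed.netloc, parsed.path.lstrip('/')
```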

I think what needs to happen inside my ImageDataset is to supply the S3 path and use the AWS CLI or something similar to list the files and fetch their contents. I do not think the AWS CLI is the right approach, though, because it relies on the command line: I would have to run sub-process commands and then parse their output.
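Listing S3 objects does not require shelling out to the AWS CLI; boto3 exposes the same ListObjectsV2 API directly. The sketch below assumes a boto3 S3 client (for example `boto3.client('s3')`) is passed in; `list_jpg_keys` is a hypothetical name. A paginator is used because a single ListObjectsV2 response returns at most 1,000 keys.

```python
def list_jpg_keys(s3_client, bucket, prefix=''):
    """Return every .jpg object key under `prefix` in `bucket`.

    `s3_client` is assumed to be a boto3 S3 client, e.g. boto3.client('s3').
    """
    paginator = s3_client.get_paginator('list_objects_v2')
    keys = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # 'Contents' is missing from a page when no keys match
        for obj in page.get('Contents', []):
            if obj['Key'].endswith('.jpg'):
                keys.append(obj['Key'])
    return keys
```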

There must be a recipe or something to create a custom Dataset backed by S3 files, right?

I was able to create a PyTorch Dataset backed by S3 data using boto3. Here's the snippet if anyone is interested.

import tempfile

import boto3
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class ImageDataset(Dataset):
    def __init__(self, path, transform=None):
        # `path` is the name of the S3 bucket that holds the images
        self.path = path
        self.s3 = boto3.resource('s3')
        self.bucket = self.s3.Bucket(path)
        # list every .jpg key in the bucket (the collection pages through results)
        self.files = [obj.key for obj in self.bucket.objects.all() if obj.key.endswith('.jpg')]
        self.transform = transform
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_name = self.files[idx]

        # we may infer the label from the filename
        dash_idx = img_name.rfind('-')
        dot_idx = img_name.rfind('.')
        label = int(img_name[dash_idx + 1:dot_idx])

        # download the object from S3 into a local temporary file;
        # the file is deleted automatically when the `with` block exits
        obj = self.bucket.Object(img_name)
        with tempfile.NamedTemporaryFile(suffix='.jpg') as tmp:
            obj.download_fileobj(tmp)
            tmp.flush()
            tmp.seek(0)
            # convert() forces PIL to read the pixels before the file is removed
            image = Image.open(tmp).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label
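A variant of the snippet above skips the temporary file and reads each object straight into memory. This swaps in a different technique (an in-memory `io.BytesIO` buffer filled through the S3 client's `get_object` call) and assumes a boto3 S3 client such as `boto3.client('s3')`; `fetch_object_bytes` is a hypothetical helper:

```python
import io


def fetch_object_bytes(s3_client, bucket, key):
    """Download one S3 object into an in-memory buffer instead of a temp file.

    `s3_client` is assumed to be a boto3 S3 client, e.g. boto3.client('s3').
    """
    response = s3_client.get_object(Bucket=bucket, Key=key)
    # response['Body'] is a streaming body; read() pulls the whole object into memory
    return io.BytesIO(response['Body'].read())
```

Inside `__getitem__`, the temporary-file block could then become something like `image = Image.open(fetch_object_bytes(client, bucket_name, img_name)).convert('RGB')`, where `client` and `bucket_name` are hypothetical attributes holding the client and bucket name. Only one image is held in memory per call, and nothing is written to local disk.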

When you train on a remote SageMaker instance, the SageMaker service starts a new EC2 instance, copies each input channel (for example train and test) from S3 into a folder on the instance's local disk (under /opt/ml/input/data/<channel_name> in the default File input mode), and then runs your training script there.

Therefore, your training script can find the local directory of each channel through an environment variable named SM_CHANNEL_<CHANNEL_NAME>, for example os.environ['SM_CHANNEL_TRAIN'] for the train channel.
Learn more here.
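In practice this means the training script should read its data directory from that variable rather than relying on the hard-coded `./images` default. A minimal sketch of the common script-mode pattern, assuming the `train` and `eval` channel names used in the `estimator.fit(inputs)` call; the `--train-dir` / `--eval-dir` flag names are hypothetical:

```python
import argparse
import os


def parse_args(argv=None):
    """Parse training-script arguments (hypothetical flag names).

    Inside a SageMaker training container, SM_CHANNEL_TRAIN / SM_CHANNEL_EVAL
    hold the local paths of the 'train' and 'eval' channels; outside SageMaker
    the defaults fall back to './images'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-dir', default=os.environ.get('SM_CHANNEL_TRAIN', './images'))
    parser.add_argument('--eval-dir', default=os.environ.get('SM_CHANNEL_EVAL', './images'))
    return parser.parse_args(argv)
```

The dataset would then be built with `ImageDataset(path=args.train_dir)`, so the same script runs unchanged on a desktop (falling back to `./images`) and inside the SageMaker container.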
