繁体   English   中英

如何实施 PyTorch 数据集以与 AWS SageMaker 一起使用?

[英]How do I implement a PyTorch Dataset for use with AWS SageMaker?

我已经实现了一个在本地(在我自己的桌面上)工作的 PyTorch Dataset ,但是在 AWS SageMaker 上执行时,它会中断。 我的Dataset实现如下。

class ImageDataset(Dataset):
    def __init__(self, path='./images', transform=None):
        self.path = path
        self.files = [join(path, f) for f in listdir(path) if isfile(join(path, f)) and f.endswith('.jpg')]
        self.transform = transform
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

    def __len__(self):
        return len(files)

    def __getitem__(self, idx):
        img_name = self.files[idx]

        # we may infer the label from the filename
        dash_idx = img_name.rfind('-')
        dot_idx = img_name.rfind('.')
        label = int(img_name[dash_idx + 1:dot_idx])

        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image, label

我下面这个例子,这一个太,我跑的estimator如下。

inputs = {
 'train': 'file://images',
 'eval': 'file://images'
}
estimator = PyTorch(entry_point='pytorch-train.py',
                            role=role,
                            framework_version='1.0.0',
                            train_instance_count=1,
                            train_instance_type=instance_type)
estimator.fit(inputs)

我收到以下错误。

FileNotFoundError: [Errno 2] 没有这样的文件或目录:'./images'

在我下面的示例中,他们将 CFAIR 数据集(在本地下载)上传到 S3。

inputs = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix='data/cifar10')

如果我看一下inputs ,它只是一个字符串文字s3://sagemaker-us-east-3-184838577132/data/cifar10 此处显示了创建DatasetDataLoader的代码,除非我跟踪源并逐步执行逻辑,否则这无济于事。

我认为在我的ImageDataset需要做的是提供S3路径并使用AWS CLI或其他东西来查询文件并获取它们的内容。 我不认为AWS CLI是正确的方法,因为这依赖于控制台,我将不得不执行一些子流程命令然后解析。

必须有配方或其他东西来创建由S3文件支持的自定义Dataset ,对吗?

我能够使用boto3创建由 S3 数据支持的 PyTorch Dataset 如果有人感兴趣,这是片段。

class ImageDataset(Dataset):
    def __init__(self, path='./images', transform=None):
        self.path = path
        self.s3 = boto3.resource('s3')
        self.bucket = self.s3.Bucket(path)
        self.files = [obj.key for obj in self.bucket.objects.all()]
        self.transform = transform
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

    def __len__(self):
        return len(files)

    def __getitem__(self, idx):
        img_name = self.files[idx]

        # we may infer the label from the filename
        dash_idx = img_name.rfind('-')
        dot_idx = img_name.rfind('.')
        label = int(img_name[dash_idx + 1:dot_idx])

        # we need to download the file from S3 to a temporary file locally
        # we need to create the local file name
        obj = self.bucket.Object(img_name)
        tmp = tempfile.NamedTemporaryFile()
        tmp_name = '{}.jpg'.format(tmp.name)

        # now we can actually download from S3 to a local place
        with open(tmp_name, 'wb') as f:
            obj.download_fileobj(f)
            f.flush()
            f.close()
            image = Image.open(tmp_name)

        if self.transform:
            image = self.transform(image)

        return image, label

当您在 SageMaker 远程实例上训练时,SageMaker 服务会启动一个新的 EC2 实例并将训练/测试通道复制到 EC2 实例本地磁盘上的文件夹,然后在那里启动您的训练脚本。

因此,您可以使用os.environ['SM_CHANNEL_TRAIN']找出您的训练/测试/验证通道的位置。
在此处了解更多信息

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM