简体   繁体   English

将 Pytorch model.state_dict() 保存到 s3

[英]Saving Pytorch model.state_dict() to s3

I am trying to save a trained Pytorch model to S3.我正在尝试将经过训练的 Pytorch 模型保存到 S3。 However, the torch.save(model.state_dict(), file_name) seems to support only local files.但是, torch.save(model.state_dict(), file_name)似乎只支持本地文件。 How can the state dict be saved to an S3 file?如何将状态字典保存到 S3 文件中?

I'm using Torch 0.4.0我正在使用 Torch 0.4.0

As discussed by Soumith Chintala , Pytorch doesn't have custom APIs to do this job.正如Soumith Chintala所讨论的,Pytorch 没有自定义 API 来完成这项工作。 However you can use boto3 or Petastorm library to solve the problem.但是你可以使用 boto3 或 Petastorm 库来解决这个问题。

Here's a concrete example to write to an S3 object directly:这是一个直接写入 S3 对象的具体示例:

import boto3

# Convert your existing model to JSON
saved_model = model.to_json()

# Write JSON object to S3 as "model.json"
client = boto3.client('s3')
client.put_object(Body=saved_model,
                  Bucket='BUCKET_NAME',
                  Key='model.json')

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM