
How to output a model pickle file to S3 in Luigi?

I have a task which trains a model, e.g.:

class ModelTrain(luigi.Task):
    def output(self):
        client = S3Client(os.getenv("CONFIG_AWS_ACCESS_KEY"),
                          os.getenv("CONFIG_AWS_SECRET_KEY"))
        model_output = os.path.join(
            "s3://", _BUCKET, exp.version + '_model.joblib')
        return S3Target(model_output, client) 

    def run(self):
        joblib.dump(model, '/tmp/model.joblib')
        with open(self.output().path, 'wb') as out_file:
            out_file.write(joblib.load('/tmp/model.joblib'))

FileNotFoundError: [Errno 2] No such file or directory: 's3://bucket/version_model.joblib'

Any pointers in this regard would be helpful.

Could you try removing .path from your open statement?

  def run(self):
    joblib.dump(model, '/tmp/model.joblib')
    with open(self.output(), 'wb') as out_file:
        out_file.write(joblib.load('/tmp/model.joblib'))

A few suggestions:

First, make sure you're using the actual self.output().open() method instead of wrapping open(self.output().path). Calling the builtin open() loses the 'atomicity' of the Luigi targets, and those targets are supposed to be swappable, so if you changed back to a LocalTarget your code should work the same way. You let the specific target class handle what it means to open the file. The error you get shows Python trying to find a local path called 's3://...', which obviously doesn't work.
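For illustration, here is a minimal sketch of that swappability. The _BUCKET name comes from the question; the use_s3 parameter and train_model() helper are placeholders I made up, not part of the original code. Because run() only goes through self.output().open(), output() can return either an S3Target or a LocalTarget and run() never changes:

import os
import pickle

import luigi
from luigi import format
from luigi.contrib.s3 import S3Client, S3Target


class ModelTrain(luigi.Task):
    # Hypothetical switch, just to show that only output() changes
    use_s3 = luigi.BoolParameter(default=True)

    def output(self):
        if self.use_s3:
            client = S3Client(os.getenv("CONFIG_AWS_ACCESS_KEY"),
                              os.getenv("CONFIG_AWS_SECRET_KEY"))
            return S3Target("s3://" + _BUCKET + "/model.pickle",
                            client=client, format=format.Nop)
        return luigi.LocalTarget("/tmp/model.pickle", format=format.Nop)

    def run(self):
        model = train_model()  # placeholder for the question's training code
        # The target decides what "open" means (local file vs. S3 upload)
        # and gives you an atomic write either way.
        with self.output().open('w') as f:
            f.write(pickle.dumps(model))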

Second, I just ran into the same issue, so here's my solution plugged into this code:

import os

import joblib
import luigi
from luigi import format
from luigi.contrib.s3 import S3Client, S3Target


class ModelTrain(luigi.Task):
    def output(self):
        client = S3Client(os.getenv("CONFIG_AWS_ACCESS_KEY"),
                          os.getenv("CONFIG_AWS_SECRET_KEY"))
        model_output = os.path.join(
            "s3://", _BUCKET, exp.version + '_model.joblib')
        # Use luigi.format.Nop so the target opens in binary mode;
        # note that client must be passed as a keyword argument
        return S3Target(model_output, client=client, format=format.Nop)

    def run(self):
        # where does `model` come from?
        with self.output().open('w') as s3_f:
            joblib.dump(model, s3_f)

My task uses pickle, so to load the object back in I had to follow something similar to this post:

import pickle

from luigi import Task


class MyNextTask(Task):
    ...

    def run(self):
        with my_pickled_task.output().open() as f:
            # The S3Target implements a read method, and then I can use
            # the `.loads()` method to import from a binary string
            results = pickle.loads(f.read())

        ... do more stuff with results ...
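For completeness, here is a minimal sketch of the upstream task that a reader like MyNextTask could consume; the bucket path and payload are placeholders. It just writes pickle.dumps() bytes through a format.Nop target, mirroring the joblib example above:

import pickle

import luigi
from luigi import format
from luigi.contrib.s3 import S3Target


class MyPickledTask(luigi.Task):
    def output(self):
        # Placeholder path; any S3 (or local) target with format.Nop works
        return S3Target("s3://my-bucket/results.pickle", format=format.Nop)

    def run(self):
        results = {"accuracy": 0.9}  # placeholder payload
        # format.Nop gives a binary-mode handle, so pickled bytes can be
        # written directly; the upload completes when the block exits
        with self.output().open('w') as f:
            f.write(pickle.dumps(results))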

I recognize this post is stale, but I'm putting the solution I found out there for the next poor soul trying to do the same thing.
