Using SageMaker with Hydra

I have a question about SageMaker and Hydra.

TL;DR: Is there a way to pass arguments from a SageMaker estimator to a Hydra script? Currently the estimator passes parameters in a fixed format that Hydra does not accept.

Full question: I use Hydra to pass configs to my training script. I have many configs and it works well for me. For example, if I want to use a specific optimizer, I run:

python train.py optimizer=adam

This is my training script, for instance:

import logging

import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="configs/", config_name="config")
def train(config: DictConfig):
    logging.info(f"Instantiating dataset <{config.dataset._target_}>")
    train_ds, val_ds = hydra.utils.call(config.dataset)

    logging.info(f"Instantiating model <{config.model._target_}>")
    model = hydra.utils.call(config.model)

    logging.info(f"Instantiating optimizer <{config.optimizer._target_}>")
    optimizer = hydra.utils.instantiate(config.optimizer)

    logging.info(f"Instantiating loss <{config.loss._target_}>")
    loss = hydra.utils.instantiate(config.loss)

    # Build every callback that declares a _target_ in its config.
    callbacks = []
    if "callbacks" in config:
        for _, cb_conf in config.callbacks.items():
            if "_target_" in cb_conf:
                logging.info(f"Instantiating callback <{cb_conf._target_}>")
                callbacks.append(hydra.utils.instantiate(cb_conf))

    # Build every metric that declares a _target_ in its config.
    metrics = []
    if "metrics" in config:
        for _, metric_conf in config.metrics.items():
            if "_target_" in metric_conf:
                logging.info(f"Instantiating metric <{metric_conf._target_}>")
                metrics.append(hydra.utils.instantiate(metric_conf))

    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    # The fitting step (which would use train_ds, val_ds and callbacks) is not shown here.


if __name__ == "__main__":
    train()

I also have a corresponding optimizer/adam.yaml file.

I have now started using SageMaker to run my experiments in the cloud, and I noticed a problem: it doesn't support Hydra's override syntax (for example, +optimizer=sgd).

Is there a way to make it play nicely with Hydra syntax? If not, do you have a suggestion for refactoring my training code so that it would work nicely with Hydra/OmegaConf?

I saw a similar question on the SageMaker issues page, but it has no replies: https://github.com/aws/sagemaker-python-sdk/issues/1837

You could pass the arguments as environment variables and read them in your training script.

The estimator's environment argument takes a dict of environment variables to set in the training container: https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#estimators
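As a minimal sketch of the environment-variable approach, the helper below reads Hydra overrides from a single variable and rebuilds the argument list before Hydra parses it. The variable name HYDRA_OVERRIDES and both function names are hypothetical, chosen for illustration only. The sketch also assumes that any other command-line arguments come from the training container and are not in Hydra syntax, so it drops them.

```python
import os
import shlex
import sys

# Hypothetical variable name; it would be set through the estimator's
# `environment` dict, e.g. {"HYDRA_OVERRIDES": "optimizer=adam"}.
OVERRIDES_ENV_VAR = "HYDRA_OVERRIDES"


def hydra_overrides_from_env(environ=None, var_name=OVERRIDES_ENV_VAR):
    """Return the Hydra overrides stored in one environment variable as a list."""
    environ = os.environ if environ is None else environ
    # shlex.split honours shell-style quoting, so values with spaces survive.
    return shlex.split(environ.get(var_name, ""))


def argv_with_overrides(argv, environ=None):
    """Keep only the script name and append the overrides from the environment.

    Assumption: every other argument was added by the training container and
    is not valid Hydra syntax, so it is dropped rather than forwarded.
    """
    return [argv[0]] + hydra_overrides_from_env(environ)
```

In the training script, one way to use this would be to set sys.argv = argv_with_overrides(sys.argv) inside the if __name__ == "__main__": block, just before calling train(), since hydra.main reads the command line when the decorated function is called. On the launching side, the same string would go into the estimator's environment dict under the chosen variable name.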
