繁体   English   中英

将 SageMaker 与 Hydra 一起使用

[英]Using SageMaker with Hydra

我有一个关于 SageMaker 和 Hydra 的问题。

TL;DR有没有办法将 arguments 从 SageMaker 估计器传递到 Hydra 脚本? 目前它以非常严格的方式传递参数。

完整问题我使用 Hydra 将配置传递给我的训练脚本。 我有很多配置,它对我很有用。 例如,如果我想使用特定的优化器,我会这样做:

python train.py optimizer=adam

这是我的训练脚本,例如:

@hydra.main(version_base=None, config_path="configs/", config_name="config")
def train(config: DictConfig):
    logging.info(f"Instantiating dataset <{config.dataset._target_}>")
    train_ds, val_ds = hydra.utils.call(config.dataset)

    logging.info(f"Instantiating model <{config.model._target_}>")
    model = hydra.utils.call(config.model)

    logging.info(f"Instantiating optimizer <{config.optimizer._target_}>")
    optimizer = hydra.utils.instantiate(config.optimizer)

    logging.info(f"Instantiating loss <{config.loss._target_}>")
    loss = hydra.utils.instantiate(config.loss)

    callbacks = []
    if "callbacks" in config:
        for _, cb_conf in config.callbacks.items():
            if "_target_" in cb_conf:
                logging.info(f"Instantiating callback <{cb_conf._target_}>")
                callbacks.append(hydra.utils.instantiate(cb_conf))

    metrics = []
    if "metrics" in config:
        for _, metric_conf in config.metrics.items():
            if "_target_" in metric_conf:
                logging.info(f"Instantiating metric <{metric_conf._target_}>")
                metrics.append(hydra.utils.instantiate(metric_conf))

    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=config.epochs,
        callbacks=callbacks,
    )


if __name__ == "__main__":
    train()

我有一个相关的optimizer/adam.yaml文件。

现在,我开始使用 SageMaker 在云中运行我的实验,我发现了一个问题。 它不支持 hydra 语法 ( +optimizer=sgd ) 之类的东西。

有没有办法让它与 Hydra 语法很好地配合使用? 如果没有,您有没有建议重构我的训练代码,以便它可以与 Hydra/OmegaConf 很好地配合使用?

我在 SageMaker 问题页面中看到了类似的问题,但没有任何回复: https://github.com/aws/sagemaker-python-sdk/issues/1837

您可以查看将 arguments 作为 ENV 传递并在您的训练脚本中摄取它们吗?

您可以传递包含 ENV 的 dict: https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#estimators

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM