簡體   English   中英

amazon sagemaker 自定義代碼增量訓練

[英]incremental training on custom code in amazon sagemaker

我正在amazon sagemaker中邁出第一步。 我正在使用腳本模式來訓練分類算法。 訓練很好,但是我無法進行增量訓練。 我想用新數據再次訓練同一個模型。 這是我所做的。 這是我的腳本:

import sagemaker
from sagemaker.tensorflow import TensorFlow
from sagemaker import get_execution_role

bucket = 'sagemaker-blablabla'
train_data = 's3://{}/{}'.format(bucket,'train')
validation_data = 's3://{}/{}'.format(bucket,'test')

s3_output_location = 's3://{}'.format(bucket)

tf_estimator = TensorFlow(entry_point='main.py', 
                          role=get_execution_role(),
                          train_instance_count=1, 
                          train_instance_type='ml.p2.xlarge',
                          framework_version='1.12', 
                          py_version='py3',
                          output_path=s3_output_location)

inputs = {'train': train_data, 'test': validation_data}
tf_estimator.fit(inputs)

入口點是我的自定義 keras 代碼,我對其進行了調整以接收來自腳本的參數。 現在培訓已成功完成,我的 s3 存儲桶中有 model.tar.gz。 我想再次訓練,但我不清楚該怎么做。我試過了

trained_model = 's3://sagemaker-blablabla/sagemaker-tensorflow-scriptmode-2019-11-27-12-01-42-300/output/model.tar.gz'

tf_estimator = sagemaker.estimator.Estimator(image_name='blablabla-west-1.amazonaws.com/sagemaker-tensorflow-scriptmode:1.12-gpu-py3', 
                                              role=get_execution_role(),
                                              train_instance_count=1, 
                                              train_instance_type='ml.p2.xlarge',
                                              output_path=s3_output_location,
                                              model_uri = trained_model)

inputs = {'train': train_data, 'test': validation_data}

tf_estimator.fit(inputs)

不起作用。 首先,我不知道如何檢索訓練圖像名稱(為此我在aws控制台中查找它,但我想應該有一個更聰明的解決方案),其次這段代碼拋出一個關於入口點的異常但是它是我的理解是,當我使用現成的圖像進行增量學習時,我不需要它。 我肯定錯過了一些重要的東西,有什么幫助嗎? 謝謝!

增量訓練是內置圖像分類器和對象檢測器的原生功能。 對於自定義代碼,開發人員有責任編寫增量訓練邏輯並驗證其有效性。 這是一個可能的路徑:

  1. 使用fit中傳遞的數據通道之一加載模型狀態(工件微調)
  2. 在您的代碼中,檢查模型狀態通道是否充滿了工件。 如果是,則從該狀態實例化一個模型並繼續訓練。 這是特定於框架的,您可以采取必要的預防措施以避免忘記以前的知識。

一些框架為增量學習提供了比其他框架更好的支持。 例如,某些 sklearn 模型提供了incremental_fit方法。 對於 DL 框架,從檢查點繼續訓練在技術上非常容易,但如果新數據與以前看到的數據有很大不同,這可能會導致您的模型忘記以前的學習。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM