簡體   English   中英

通過 Sagemaker 的 XGBoost 內置算法和訓練容器擬合 model 進行預測的數據格式

[英]data format to predict with model fitted via Sagemaker's XGBoost built-in algorithm and training container

查看以下摘自此處的代碼,我想知道 dtest 是什么格式(抱歉,我無法從帖子中看到這一點):

import pickle as pkl 
import tarfile

t = tarfile.open('model.tar.gz', 'r:gz')
t.extractall()

model = pkl.load(open(model_file_path, 'rb'))

# prediction with test data
pred = model.predict(dtest)

在我的例子中,訓練和驗證數據采用 csv 格式,來自 S3 存儲桶:

content_type = "csv"
train_input = TrainingInput("s3://{}/{}/{}/".format(bucket, prefix, 'train'), content_type=content_type)

所以理想情況下,我也想使用相同的格式進行評分/預測/推理。

附言:

這個小 function 似乎工作正常:

def write_prediction_data(data_file_name, target_name, model_file_name, output_file_name):

    model = pkl.load(open(model_file_name, 'rb'))
    data = pd.read_csv(data_file_name) 
    target = data[target_name]
    data = data.drop([target_name], axis=1)
    xgb_data = xgb.DMatrix(data.values, target.values)

    data = pd.read_csv(data_file_name)
    data['Prediction'] = model.predict(xgb_data)

    data.to_csv(output_file_name, index=False)

隨時歡迎改進建議 (-:

“dtest”格式將是 csv,沒有任何 label 列。 除了由 model 正確處理外,它沒有特定的格式要求。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM