[英]data format to predict with model fitted via Sagemaker's XGBoost built-in algorithm and training container
查看以下摘自此處的代碼,我想知道 dtest 是什么格式(抱歉,我無法從帖子中看到這一點):
import pickle as pkl
import tarfile
t = tarfile.open('model.tar.gz', 'r:gz')
t.extractall()
model = pkl.load(open(model_file_path, 'rb'))
# prediction with test data
pred = model.predict(dtest)
在我的例子中,訓練和驗證數據采用 csv 格式,來自 S3 存儲桶:
content_type = "csv"
train_input = TrainingInput("s3://{}/{}/{}/".format(bucket, prefix, 'train'), content_type=content_type)
所以理想情況下,我也想使用相同的格式進行評分/預測/推理。
附言:
這個小 function 似乎工作正常:
def write_prediction_data(data_file_name, target_name, model_file_name, output_file_name):
model = pkl.load(open(model_file_name, 'rb'))
data = pd.read_csv(data_file_name)
target = data[target_name]
data = data.drop([target_name], axis=1)
xgb_data = xgb.DMatrix(data.values, target.values)
data = pd.read_csv(data_file_name)
data['Prediction'] = model.predict(xgb_data)
data.to_csv(output_file_name, index=False)
隨時歡迎改進建議 (-:
“dtest”格式將是 csv,沒有任何 label 列。 除了由 model 正確處理外,它沒有特定的格式要求。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.