繁体   English   中英

Google Cloud ML引擎scikit学习预测概率'predict_proba()'

[英]Google Cloud ML-engine scikit-learn prediction probability 'predict_proba()'

Google Cloud ML引擎支持部署scikit-learn Pipeline对象的功能。 例如,文本分类Pipeline可能如下所示,

classifier = Pipeline([
('vect', CountVectorizer()), 
('clf', naive_bayes.MultinomialNB())])

可以训练分类器,

classifier.fit(train_x, train_y)

然后可以将分类器上传到Google Cloud Storage,

model = 'model.joblib'
joblib.dump(classifier, model)
model_remote_path = os.path.join('gs://', bucket_name, datetime.datetime.now().strftime('model_%Y%m%d_%H%M%S'), model)
subprocess.check_call(['gsutil', 'cp', model, model_remote_path], stderr=sys.stdout)

然后,可以通过Google Cloud Console或以编程方式'model.joblib'文件链接到Version来创建ModelVersion

然后,该分类器可用于通过调用已部署的模型predict端点来预测新数据,

ml = discovery.build('ml','v1')
project_id = 'projects/{}/models/{}'.format(project_name, model_name)
if version_name is not None:
    project_id += '/versions/{}'.format(version_name)
request_dict = {'instances':['Test data']}
ml_request = ml.projects().predict(name=project_id, body=request_dict).execute()

Google Cloud ML引擎调用分类器的predict函数并返回预测的类。 但是,我希望能够返回置信度分数。 通常,这可以通过调用分类器的predict_proba函数来实现,但是似乎没有选择来更改被调用函数。 我的问题是:使用Google Cloud ML引擎时,是否可以返回scikit学习分类器的置信度分数? 如果没有,您对其他如何获得此结果有什么建议?

更新:我发现了一个hacky解决方案。 它涉及到用自己的predict_proba函数覆盖分类器的predict函数,

nb = naive_bayes.MultinomialNB()
nb.predict = nb.predict_proba
classifier = Pipeline([
('vect', CountVectorizer()), 
('clf', nb)])

出人意料的是,这可行。 如果有人知道更整洁的解决方案,请告诉我。

更新: Google发布了一项名为“ Custom prediction routines的新功能(当前为Beta)。 这使您可以定义在出现预测请求时运行什么代码。它为解决方案添加了更多代码,但肯定没有那么多hacky。

正如您在文档中所看到的那样,您所使用的ML Engine API仅具有预测方法,因此它将仅进行预测(除非您强迫它对您提到的hack进行其他操作)。

如果您想对训练后的模型进行其他操作,则必须加载并正常使用。 如果要使用存储在云存储中的模型,可以执行以下操作:

from google.cloud import storage
from sklearn.externals import joblib

bucket_name = "<BUCKET_NAME>"
gs_model = "path/to/model.joblib"  # path in your Cloud Storage bucket
local_model = "/path/to/model.joblib"  # path in your local machine

client = storage.Client()
bucket = client.get_bucket(bucket_name)
blob = bucket.blob(gs_model)
blob.download_to_filename(local_model)

model = joblib.load(local_model)
model.predict_proba(test_data)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM