[英]How to load and predict a pre-trained tensorflow model into Java code?
[英]How can I load a Tensorflow SavedModel in Java with 'predict' Sgnature Def?
我正在訓練Tensorflow Estimator,並使用export_saved_model
將模型保存為SavedModel格式。 現在我想使用Tensorflow Java API加載此模型(我不想使用模型服務器,我需要直接在Java中加載它)。 現在的問題是Estimator.export_saved_model
僅導出“預測” signature_def,而Java中的SavedModelBundle
似乎僅支持具有“ serving_default”簽名def的模型。 所以問題是:有沒有辦法使Estimator.export_saved_model
包含'serving_default'簽名def? 或者是否可以在Java中使用“預測”簽名def加載模型? 還是我可以嘗試其他選擇?
這是導出模型的代碼:
feature_cols = [
tf.feature_column.numeric_column('numeric_feature'),
tf.feature_column.indicator_column( tf.feature_column.categorical_column_with_vocabulary_list('categorial_text_feature', vocabulary_list=['WORD1', 'WORD1']))
]
estimator = tf.estimator.LinearRegressor(
feature_columns=feature_cols,
model_dir=model_dir,
label_dimension=1)
estimator.train(input_fn=input_fn)
serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'numeric_feature': tf.placeholder(tf.float32, shape=(None,)),
'categorial_text_feature': tf.placeholder(tf.string, shape=(None,))
})
estimator.export_saved_model(
export_dir_base=model_dir,
serving_input_receiver_fn=serving_input_receiver_fn)
如果我使用saved_model_cli show --tag_set serve
檢查模型, saved_model_cli show --tag_set serve
得到:
The given SavedModel MetaGraphDef contains SignatureDefs with the following keys:
SignatureDef key: "predict"
並通過saved_model_cli show --tag_set serve --signature_def predict
:
The given SavedModel SignatureDef contains the following input(s):
inputs['numeric_feature'] tensor_info:
dtype: DT_FLOAT
shape: (-1)
name: Placeholder:0
inputs['categorial_text_feature'] tensor_info:
dtype: DT_STRING
shape: (-1)
name: Placeholder_1:0
The given SavedModel SignatureDef contains the following output(s):
outputs['predictions'] tensor_info:
dtype: DT_FLOAT
shape: (-1)
name: linear/linear_model/linear_model/linear_model/weighted_sum:0
Method name is: tensorflow/serving/predict
並通過saved_model_cli show --tag_set serve --signature_def serving_default
:
The given SavedModel SignatureDef contains the following input(s):
The given SavedModel SignatureDef contains the following output(s):
Method name is:
現在,當我在Java中加載模型並輸入一些值時:
SavedModelBundle model = SavedModelBundle.load(modelPath, "serve");
Tensor<?> numericTensor = Tensors.create(new float[] { 10.3 });
Tensor<?> stringTensor = Tensors.create(new byte[][] { "WORD1".getBytes() });
Tensor<?> output = model
.session()
.runner()
.feed("numeric_feature", numericTensor)
.feed("categorial_text_feature", stringTensor)
.fetch("predictions")
.run()
.get(0);
這導致以下錯誤:
java.lang.IllegalArgumentException: No Operation named [numeric_feature] in the Graph
at org.tensorflow.Session$Runner.operationByName(Session.java:372)
at org.tensorflow.Session$Runner.parseOutput(Session.java:381)
at org.tensorflow.Session$Runner.feed(Session.java:131)
找到了一個(不完美,但很簡單)的解決方法:
我剛用as_text=True
導出了模型:
estimator.export_saved_model(
export_dir_base=model_dir,
serving_input_receiver_fn=serving_input_receiver_fn,
as_text=True)
然后只需手動更改.pbtxt文件,以使簽名def稱為“ serving_default”
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.