
How can I load a TensorFlow SavedModel in Java with a 'predict' Signature Def?

I am training a TensorFlow Estimator and saving the model in the SavedModel format with export_saved_model. Now I want to load this model with the TensorFlow Java API (I don't want to use a model server; I need to load it directly in Java). The problem is that Estimator.export_saved_model only exports a 'predict' signature_def, while SavedModelBundle in Java seems to support only models with a 'serving_default' signature def. So the questions are: Is there a way to make Estimator.export_saved_model include a 'serving_default' signature def? Or is it possible to load the model in Java using the 'predict' signature def? Or is there another option I could try?

Here is the code that exports the model:

import tensorflow as tf

feature_cols = [
    tf.feature_column.numeric_column('numeric_feature'),
    tf.feature_column.indicator_column(
        tf.feature_column.categorical_column_with_vocabulary_list(
            'categorial_text_feature', vocabulary_list=['WORD1', 'WORD2']))
]

estimator = tf.estimator.LinearRegressor(
    feature_columns=feature_cols,
    model_dir=model_dir,
    label_dimension=1)

estimator.train(input_fn=input_fn)

serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
    'numeric_feature': tf.placeholder(tf.float32, shape=(None,)),
    'categorial_text_feature': tf.placeholder(tf.string, shape=(None,))
})

estimator.export_saved_model(
    export_dir_base=model_dir,
    serving_input_receiver_fn=serving_input_receiver_fn)

If I inspect the model with saved_model_cli show --tag_set serve, I get:

The given SavedModel MetaGraphDef contains SignatureDefs with the following keys:
SignatureDef key: "predict"

And with saved_model_cli show --tag_set serve --signature_def predict:

The given SavedModel SignatureDef contains the following input(s):
  inputs['numeric_feature'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1)
      name: Placeholder:0
  inputs['categorial_text_feature'] tensor_info:
      dtype: DT_STRING
      shape: (-1)
      name: Placeholder_1:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['predictions'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1)
      name: linear/linear_model/linear_model/linear_model/weighted_sum:0
Method name is: tensorflow/serving/predict

And with saved_model_cli show --tag_set serve --signature_def serving_default:

The given SavedModel SignatureDef contains the following input(s):
The given SavedModel SignatureDef contains the following output(s):
Method name is: 

Now, when I load the model in Java and feed some values:

SavedModelBundle model = SavedModelBundle.load(modelPath, "serve");

Tensor<?> numericTensor = Tensors.create(new float[] { 10.3f });
Tensor<?> stringTensor = Tensors.create(new byte[][] { "WORD1".getBytes() });
Tensor<?> output = model
                .session()
                .runner()
                .feed("numeric_feature", numericTensor)
                .feed("categorial_text_feature", stringTensor)
                .fetch("predictions")
                .run()
                .get(0);

This results in the following error:

java.lang.IllegalArgumentException: No Operation named [numeric_feature] in the Graph
    at org.tensorflow.Session$Runner.operationByName(Session.java:372)
    at org.tensorflow.Session$Runner.parseOutput(Session.java:381)
    at org.tensorflow.Session$Runner.feed(Session.java:131)
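
Judging from the stack trace, Session.Runner resolves feed and fetch strings as operation/tensor names in the graph, not as signature input keys. Based on the tensor names shown in the saved_model_cli output above (Placeholder:0, Placeholder_1:0 and the weighted_sum output), feeding by those names directly may already get past this error. A minimal, untested sketch (the class name and the args[0] model path are placeholders):

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

public class PredictByTensorName {
    public static void main(String[] args) {
        String modelPath = args[0]; // directory of the exported SavedModel

        try (SavedModelBundle model = SavedModelBundle.load(modelPath, "serve");
             Tensor<?> numericTensor = Tensors.create(new float[] { 10.3f });
             Tensor<?> stringTensor = Tensors.create(new byte[][] { "WORD1".getBytes() })) {

            // Feed/fetch by the tensor names listed by saved_model_cli for the
            // 'predict' signature instead of the signature's logical keys.
            try (Tensor<?> output = model
                    .session()
                    .runner()
                    .feed("Placeholder:0", numericTensor)   // inputs['numeric_feature']
                    .feed("Placeholder_1:0", stringTensor)  // inputs['categorial_text_feature']
                    .fetch("linear/linear_model/linear_model/linear_model/weighted_sum:0") // outputs['predictions']
                    .run()
                    .get(0)) {
                float[] predictions = output.copyTo(new float[1]);
                System.out.println("prediction: " + predictions[0]);
            }
        }
    }
}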

I found a (not perfect, but quite simple) workaround:

I simply exported the model with as_text=True:

estimator.export_saved_model(
        export_dir_base=model_dir,
        serving_input_receiver_fn=serving_input_receiver_fn,
        as_text=True)

Then I just edited the .pbtxt file by hand so that the signature def is named 'serving_default'.
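
An alternative that may avoid editing the exported file at all is to keep the 'predict' signature and resolve the real tensor names from the SignatureDef at runtime. The following is only a sketch: it assumes the TensorFlow 1.x Java API together with the org.tensorflow proto classes (MetaGraphDef, SignatureDef) on the classpath, and it uses the signature and key names from the saved_model_cli output above:

import com.google.protobuf.InvalidProtocolBufferException;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;

public class PredictBySignatureDef {
    public static void main(String[] args) throws InvalidProtocolBufferException {
        String modelPath = args[0]; // directory of the exported SavedModel

        try (SavedModelBundle model = SavedModelBundle.load(modelPath, "serve")) {
            // Parse the MetaGraphDef bundled with the model and look up the
            // 'predict' SignatureDef to map logical keys to tensor names.
            MetaGraphDef metaGraph = MetaGraphDef.parseFrom(model.metaGraphDef());
            SignatureDef signature = metaGraph.getSignatureDefMap().get("predict");

            String numericInput = signature.getInputsMap().get("numeric_feature").getName();
            String textInput = signature.getInputsMap().get("categorial_text_feature").getName();
            String predictionsOutput = signature.getOutputsMap().get("predictions").getName();

            try (Tensor<?> numericTensor = Tensors.create(new float[] { 10.3f });
                 Tensor<?> stringTensor = Tensors.create(new byte[][] { "WORD1".getBytes() });
                 Tensor<?> output = model.session().runner()
                         .feed(numericInput, numericTensor)
                         .feed(textInput, stringTensor)
                         .fetch(predictionsOutput)
                         .run()
                         .get(0)) {
                float[] predictions = output.copyTo(new float[1]);
                System.out.println("prediction: " + predictions[0]);
            }
        }
    }
}

With this approach the code does not depend on the signature key being called 'serving_default'; it only needs the right key ('predict' here) passed to getSignatureDefMap().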
