簡體   English   中英

TensorFlow:如何從SavedModel進行預測?

[英]TensorFlow: How to predict from a SavedModel?

我已經導出了SavedModel ,現在可以將其加載回並進行預測。 經過培訓,具有以下功能和標簽:

F1 : FLOAT32
F2 : FLOAT32
F3 : FLOAT32
L1 : FLOAT32

所以說,我想在值養活20.9, 1.8, 0.9得到一個FLOAT32預測。 我該如何完成? 我已經成功地加載了模型,但是我不確定如何訪問它以進行預測調用。

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        "/job/export/Servo/1503723455"
    )

    # How can I predict from here?
    # I want to do something like prediction = model.predict([20.9, 1.8, 0.9])

該問題不是此處發布的問題的重復。 這個問題集中於在任何模型類的SavedModel上進行推理的最小示例(不僅限於tf.estimator )以及指定輸入和輸出節點名稱的語法。

假設您要使用Python進行預測, SavedModelPredictor可能是加載SavedModel並獲取預測的最簡單方法。 假設您像這樣保存模型:

# Build the graph
f1 = tf.placeholder(shape=[], dtype=tf.float32)
f2 = tf.placeholder(shape=[], dtype=tf.float32)
f3 = tf.placeholder(shape=[], dtype=tf.float32)
l1 = tf.placeholder(shape=[], dtype=tf.float32)
output = build_graph(f1, f2, f3, l1)

# Save the model
inputs = {'F1': f1, 'F2': f2, 'F3': f3, 'L1': l1}
outputs = {'output': output_tensor}
tf.contrib.simple_save(sess, export_dir, inputs, outputs)

(輸入可以是任何形狀,甚至不必是圖形中的占位符或根節點)。

然后,在將使用SavedModel的Python程序中,我們可以得到如下預測:

from tensorflow.contrib import predictor

predict_fn = predictor.from_saved_model(export_dir)
predictions = predict_fn(
    {"F1": 1.0, "F2": 2.0, "F3": 3.0, "L1": 4.0})
print(predictions)

該答案顯示了如何在Java,C ++和Python中獲得預測(盡管問題集中在Estimators上,但該答案實際上獨立於SavedModel的創建方式而適用)。

對於需要保存訓練好的罐裝模型並在沒有tensorflow服務的情況下提供服務的工作示例的人,我已在此處記錄了https://github.com/tettusud/tensorflow-examples/tree/master/estimators

  1. 您可以從tf.tensorflow.contrib.predictor.from_saved_model( exported_model_path)創建一個預測變量
  2. 准備輸入

     tf.train.Example( features= tf.train.Features( feature={ 'x': tf.train.Feature( float_list=tf.train.FloatList(value=[6.4, 3.2, 4.5, 1.5]) ) } ) ) 

x是導出時在input_receiver_function中給出的輸入名稱。 例如:

feature_spec = {'x': tf.FixedLenFeature([4],tf.float32)}

def serving_input_receiver_fn():
    serialized_tf_example = tf.placeholder(dtype=tf.string,
                                           shape=[None],
                                           name='input_tensors')
    receiver_tensors = {'inputs': serialized_tf_example}
    features = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

加載圖形后,它在當前上下文中可用,您可以通過它饋入輸入數據以獲得預測。 每個用例都有很大的不同,但是在代碼中添加的內容如下所示:

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        "/job/export/Servo/1503723455"
    )

    prediction = sess.run(
        'prefix/predictions/Identity:0',
        feed_dict={
            'Placeholder:0': [20.9],
            'Placeholder_1:0': [1.8],
            'Placeholder_2:0': [0.9]
        }
    )

    print(prediction)

在這里,您需要知道預測輸入將是什么的名稱。 如果您沒有在serving_fn中給他們一個中殿,那么它們默認為Placeholder_n ,其中n是第n個特征。

sess.run的第一個字符串參數是預測目標的名稱。 這將根據您的用例而有所不同。

tf.estimator.DNNClassifier的構造tf.estimator.DNNClassifier具有一個稱為warm_start_from的參數。 您可以為其指定SavedModel文件夾名稱,它將恢復您的會話。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM