簡體   English   中英

Tensorflow:如何在java中使用在python中訓練的語音識別模型

[英]Tensorflow : How to use speech recognition model trained in python in java

我按照這篇文章在 python 中訓練了一個 tensorflow 模型 訓練后,我生成了凍結圖。 現在我需要使用這個圖並在基於JAVA的應用程序上生成識別。 為此,我正在查看以下示例 但是我不明白是如何收集我的輸出。 我知道我需要為圖表提供 3 個輸入。

從官方教程中給出的示例中,我閱讀了基於 python 的代碼。

def run_graph(wav_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
  """Runs the audio data through the graph and prints predictions."""
  with tf.Session() as sess:
    # Feed the audio data as input to the graph.
    #   predictions  will contain a two-dimensional array, where one
    #   dimension represents the input image count, and the other has
    #   predictions per class
    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
    predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})

    # Sort to show labels in order of confidence
    top_k = predictions.argsort()[-num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = labels[node_id]
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))

    return 0

有人可以幫我理解tensorflow java api嗎?

您上面列出的 Python 代碼的字面翻譯將是這樣的:

public static float[][] getPredictions(Session sess, byte[] wavData, String inputLayerName, String outputLayerName) {
  try (Tensor<String> wavDataTensor = Tensors.create(wavData);
       Tensor<Float> predictionsTensor = sess.runner()
                    .feed(inputLayerName, wavDataTensor)
                    .fetch(outputLayerName)
                    .run()
                    .get(0)
                    .expect(Float.class)) {
    float[][] predictions = new float[(int)predictionsTensor.shape(0)][(int)predictionsTensor.shape(1)];
    predictionsTensor.copyTo(predictions);
    return predictions;
  }
}

返回的predictions數組將具有每個預測的“置信度”值,您必須運行邏輯來計算其上的“前 K”,類似於 Python 代碼使用 numpy ( .argsort() )的方式對sess.run()返回的內容執行此操作。

從對教程頁面和代碼的粗略閱讀來看, predictions似乎有 1 行和 12 列(每個熱門詞一個)。 我從以下 Python 代碼中得到了這個:

import tensorflow as tf

graph_def = tf.GraphDef()
with open('/tmp/my_frozen_graph.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())

output_layer_name = 'labels_softmax:0'

tf.import_graph_def(graph_def, name='')
print(tf.get_default_graph().get_tensor_by_name(output_layer_name).shape)

希望有幫助。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM