
TensorFlow: How to use a speech recognition model trained in Python from Java

I have a TensorFlow model that I trained in Python by following this article. After training, I generated the frozen graph. Now I need to use this graph to run recognition in a Java-based application. For this I was looking at the following example. However, I failed to understand how to collect my output. I know that I need to provide 3 inputs to the graph.

From the example given in the official tutorial, I have read the following Python code.

def run_graph(wav_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
  """Runs the audio data through the graph and prints predictions."""
  with tf.Session() as sess:
    # Feed the audio data as input to the graph.
    #   sess.run() returns a two-dimensional array, where one dimension
    #   represents the input clip count and the other has predictions per
    #   class; the trailing comma unpacks the single row into predictions.
    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
    predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})

    # Sort to show labels in order of confidence
    top_k = predictions.argsort()[-num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = labels[node_id]
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))

    return 0

Can someone help me understand the TensorFlow Java API?

The literal translation of the Python code you listed above would be something like this:

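import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

public static float[][] getPredictions(Session sess, byte[] wavData, String inputLayerName, String outputLayerName) {
  // Both tensors are closed automatically when the try block exits.
  try (Tensor<String> wavDataTensor = Tensors.create(wavData);
       Tensor<Float> predictionsTensor = sess.runner()
                    .feed(inputLayerName, wavDataTensor)
                    .fetch(outputLayerName)
                    .run()
                    .get(0)
                    .expect(Float.class)) {
    // In the TensorFlow 1.x Java API, shape() returns a long[];
    // the output is expected to be two-dimensional: [rows, classes].
    long[] shape = predictionsTensor.shape();
    float[][] predictions = new float[(int) shape[0]][(int) shape[1]];
    predictionsTensor.copyTo(predictions);
    return predictions;
  }
}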

The returned predictions array holds the "confidence" value for each class. You'll have to compute the "top K" on it yourself, similar to how the Python code uses numpy's .argsort() on what sess.run() returned.
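As a rough illustration of that top-K step, here is a minimal sketch that uses only the Java standard library. The class name TopKPredictions, the method topK, and the labels and scores in main are made up for this example and are not part of the TensorFlow API. In a real application you would pass in one row of the array returned by getPredictions, along with the labels your model was trained on.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical helper, not part of TensorFlow: ranks class indices by score.
final class TopKPredictions {

  /** Returns the indices of the k largest scores, highest score first. */
  static int[] topK(float[] scores, int k) {
    List<Integer> indices = new ArrayList<>();
    for (int i = 0; i < scores.length; i++) {
      indices.add(i);
    }
    // Sort indices by descending score, the counterpart of
    // predictions.argsort()[-k:][::-1] in the Python code.
    indices.sort(Comparator.comparingDouble((Integer i) -> scores[i]).reversed());

    // Clamp k so that it never exceeds the number of classes.
    int n = Math.max(0, Math.min(k, scores.length));
    int[] top = new int[n];
    for (int j = 0; j < n; j++) {
      top[j] = indices.get(j);
    }
    return top;
  }

  public static void main(String[] args) {
    // Made-up labels and scores standing in for the real model output.
    String[] labels = {"yes", "no", "up", "down"};
    float[][] predictions = {{0.10f, 0.65f, 0.05f, 0.20f}};

    // Row 0 holds the scores for the single clip that was fed in.
    for (int id : topK(predictions[0], 3)) {
      System.out.printf("%s (score = %.5f)%n", labels[id], predictions[0][id]);
    }
  }
}
```

Sorting every index is simple and perfectly adequate for a dozen classes. Ties may come out in a different order than numpy's argsort would give.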

From a cursory reading of the tutorial page and code, it seems that predictions will have 1 row and 12 columns (one for each hotword). I got this from the following Python code:

import tensorflow as tf

graph_def = tf.GraphDef()
with open('/tmp/my_frozen_graph.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())

output_layer_name = 'labels_softmax:0'

tf.import_graph_def(graph_def, name='')
print(tf.get_default_graph().get_tensor_by_name(output_layer_name).shape)

Hope that helps.
