[英]Tensorflow : How to use speech recognition model trained in python in java
I have a tensorflow model trained in python by following this article.我按照这篇文章在 python 中训练了一个 tensorflow 模型。 After training I have generated the frozen graph.
训练后,我生成了冻结图。 Now I need to use this graph and generate recognition on a JAVA based application.
现在我需要使用这个图并在基于JAVA的应用程序上生成识别。 For this I was looking at the following example .
为此,我正在查看以下示例。 However I failed to understand is to how to collect my output.
但是我不明白是如何收集我的输出。 I know that I need to provide 3 inputs to the graph.
我知道我需要为图表提供 3 个输入。
From the example given on the official tutorial I have read the code that is based on python.从官方教程中给出的示例中,我阅读了基于 python 的代码。
def run_graph(wav_data, labels, input_layer_name, output_layer_name,
num_top_predictions):
"""Runs the audio data through the graph and prints predictions."""
with tf.Session() as sess:
# Feed the audio data as input to the graph.
# predictions will contain a two-dimensional array, where one
# dimension represents the input image count, and the other has
# predictions per class
softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})
# Sort to show labels in order of confidence
top_k = predictions.argsort()[-num_top_predictions:][::-1]
for node_id in top_k:
human_string = labels[node_id]
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
return 0
Can someone help me to understand the tensorflow java api?有人可以帮我理解tensorflow java api吗?
The literal translation of the Python code you listed above would be something like this:您上面列出的 Python 代码的字面翻译将是这样的:
public static float[][] getPredictions(Session sess, byte[] wavData, String inputLayerName, String outputLayerName) {
try (Tensor<String> wavDataTensor = Tensors.create(wavData);
Tensor<Float> predictionsTensor = sess.runner()
.feed(inputLayerName, wavDataTensor)
.fetch(outputLayerName)
.run()
.get(0)
.expect(Float.class)) {
float[][] predictions = new float[(int)predictionsTensor.shape(0)][(int)predictionsTensor.shape(1)];
predictionsTensor.copyTo(predictions);
return predictions;
}
}
The returned predictions
array will have the "confidence" values of each of the predictions, and you'll have to run the logic to compute the "top K" on it similar to how the Python code is using numpy ( .argsort()
) to do so on what sess.run()
returned.返回的
predictions
数组将具有每个预测的“置信度”值,您必须运行逻辑来计算其上的“前 K”,类似于 Python 代码使用 numpy ( .argsort()
)的方式对sess.run()
返回的内容执行此操作。
From a cursory reading of the tutorial page and code, it seems that predictions
will have 1 row and 12 columns (one for each hotword).从对教程页面和代码的粗略阅读来看,
predictions
似乎有 1 行和 12 列(每个热门词一个)。 I got this from the following Python code:我从以下 Python 代码中得到了这个:
import tensorflow as tf
graph_def = tf.GraphDef()
with open('/tmp/my_frozen_graph.pb', 'rb') as f:
graph_def.ParseFromString(f.read())
output_layer_name = 'labels_softmax:0'
tf.import_graph_def(graph_def, name='')
print(tf.get_default_graph().get_tensor_by_name(output_layer_name).shape)
Hope that helps.希望有帮助。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.