How do you identify a sparse tensor for output purposes?

I want to get the predictions (output) of my pre-trained model. The model predicts a symbol for each frame (column) of the convolved image, so the logits (the output of the RNN) have to be post-processed to emit the actual sequence of predicted symbols. Code for model construction can be found here.

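import tensorflow as tf  # TF 1.x graph-mode API

# graph, sess and the input / seq_len / rnn_keep_prob tensors are assumed to be restored already.
# ctc_greedy_decoder expects time-major logits: [max_time, batch_size, num_classes].
logits = graph.get_tensor_by_name("fully_connected/BiasAdd:0")
# decoded is a one-element list; decoded[0] is the SparseTensor of decoded labels.
decoded, _ = tf.nn.ctc_greedy_decoder(logits, seq_len)
prediction = sess.run(decoded[0],
                      feed_dict={
                          input: image,
                          seq_len: seq_lengths,
                          rnn_keep_prob: 1.0,  # disable dropout at inference time
                      })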
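To make the post-processing step concrete, here is a minimal plain-Python sketch of greedy CTC decoding, written to mirror what tf.nn.ctc_greedy_decoder is documented to do: take the best class per frame, merge consecutive repeats, then drop blanks. It assumes time-major logits ([max_time][batch][num_classes], the layout the TensorFlow op expects) and a blank label passed in explicitly (TF 1.x uses the last class, num_classes - 1). The helper name ctc_greedy_decode is hypothetical, not a TensorFlow API; it returns the same (indices, values, dense_shape) triple that a SparseTensorValue carries.

```python
from typing import List, Sequence, Tuple


def ctc_greedy_decode(logits: Sequence[Sequence[Sequence[float]]],
                      seq_lens: Sequence[int],
                      blank: int) -> Tuple[List[List[int]], List[int], List[int]]:
    """Greedy CTC decode of time-major logits [max_time][batch][num_classes].

    Returns (indices, values, dense_shape), the triple a SparseTensorValue holds.
    """
    batch_size = len(seq_lens)
    indices: List[List[int]] = []
    values: List[int] = []
    max_len = 0
    for b in range(batch_size):
        prev = None
        out: List[int] = []
        # Only the first seq_lens[b] frames of this example are valid.
        for t in range(seq_lens[b]):
            frame = logits[t][b]
            # Best class for this frame (argmax over classes).
            best = max(range(len(frame)), key=frame.__getitem__)
            # Emit only if it is not a blank and not a repeat of the previous frame.
            if best != blank and best != prev:
                out.append(best)
            prev = best
        for pos, label in enumerate(out):
            indices.append([b, pos])
            values.append(label)
        max_len = max(max_len, len(out))
    return indices, values, [batch_size, max_len]
```

For a model with 3 classes, blank is 2; a frame sequence whose argmaxes are 0, 0, blank, 0, 1 decodes to [0, 0, 1], because the blank separates the two 0s while the adjacent 0s are merged.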

prediction is a SparseTensorValue containing every predicted symbol, and decoded is a list holding a single SparseTensor with the non-empty decoded sequences. Finally, I parse the resulting SparseTensorValue to get the strings I need.
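Parsing the SparseTensorValue mostly means grouping the values by the first column of indices (the batch row) and ordering them by the second column (the position within the sequence). A minimal sketch, assuming a hypothetical charset string whose i-th character corresponds to label id i (the real mapping depends on how the model was trained):

```python
from typing import Dict, List, Sequence, Tuple


def sparse_to_strings(indices: Sequence[Sequence[int]],
                      values: Sequence[int],
                      dense_shape: Sequence[int],
                      charset: str) -> List[str]:
    """Group decoded label ids by batch row and map each id to a character.

    charset is a hypothetical lookup string: charset[i] is the symbol for label i.
    """
    batch_size = dense_shape[0]
    rows: Dict[int, List[Tuple[int, str]]] = {b: [] for b in range(batch_size)}
    for (b, pos), label in zip(indices, values):
        rows[b].append((pos, charset[label]))
    # Sort by position in case the indices are not already in row-major order.
    return ["".join(ch for _, ch in sorted(rows[b])) for b in range(batch_size)]
```

With the real result this would be called as sparse_to_strings(prediction.indices, prediction.values, prediction.dense_shape, charset); rows with no decoded symbols come back as empty strings.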

I want to use this trained model for inference through either TensorFlow Serving or TFLite; however, to do so I need to specify the model's output nodes. Because of the nature of sparse tensors, I can't specify the output by name. Is there a way to use this model for proper inference?
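A SparseTensor is not a single graph node: it is backed by three dense tensors (indices, values and dense_shape), and each of them has its own name in the graph (for example decoded[0].values.name). One option, not verified against this particular model, is to list those three tensors as the outputs, or to convert the result into a single dense tensor with tf.sparse.to_dense(decoded[0], default_value=-1) (tf.sparse_tensor_to_dense in older 1.x releases) and export that. The plain-Python sketch below mirrors that conversion to show the padded [batch, max_len] layout a client would receive; the helper name is hypothetical, and -1 as the padding value is an assumption chosen because it cannot collide with a valid label id.

```python
from typing import List, Sequence


def sparse_to_dense(indices: Sequence[Sequence[int]],
                    values: Sequence[int],
                    dense_shape: Sequence[int],
                    default_value: int = -1) -> List[List[int]]:
    """Scatter (indices, values) into a padded [batch, max_len] matrix."""
    rows, cols = dense_shape[0], dense_shape[1]
    # Start with every slot set to the padding value.
    dense = [[default_value] * cols for _ in range(rows)]
    for (r, c), v in zip(indices, values):
        dense[r][c] = v
    return dense
```

A client can then strip the padding from each row instead of reassembling a sparse structure.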

I've seen many examples that use CTC decoders like this one for prediction in a similar way, but none that use such models for inference without relying closely on the TensorFlow API, so I am unsure how to proceed.

You can save your model in the TensorFlow SavedModel format. After that you can use the saved_model_cli command-line tool, which is installed with the tensorflow package, to inspect all model signatures: saved_model_cli show --dir . --all. With it you will see all information about the inputs and outputs, including their names, dtypes and shapes. The default signature key is serving_default.
