簡體   English   中英

如何將 decode_batch_predictions() 方法添加到 Keras 手寫識別 OCR model?

[英]How can I add the decode_batch_predictions() method into the Keras Handwriting Recognition OCR model?

我需要將decode_batch_predictions()方法添加到Keras 手寫識別 OCR model的 output 中。 這樣做的原因是我想將 model 轉換為 TF Lite,並且我希望對 output 進行解碼,因為我沒有找到任何方法來解碼 Android 中 TF Lite 上的 output。我已經看到了類似的帖子Keras model 但它不適用於這個 model。我對 Python 了解不多,所以我很難為這個 model 調整那個帖子的答案所以我真的很感激任何幫助,謝謝!

我嘗試使用該帖子中的代碼,但它不起作用

在您的鏈接中給出的 model 的筆記本中,在prediction_model之后進行以下更改:

prediction_model = keras.models.Model(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
) # This line is present in the handwriting_recognition notebook.

def CTCDecoder():
  def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][:, :max_length]
    # Iterate over the results and get back the text
    output_text = []
    for res in results:
        #print(res)
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text

  return tf.keras.layers.Lambda(decode_batch_predictions, name='decode')

decoded_pred_model = keras.models.Model(prediction_model.input, outputs=CTCDecoder()(prediction_model.output))

decoded_pred_model轉換為 a.tflite 並在 android 中使用。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM