[英]How to use tf.keras.backend.ctc_decode in a sequential model?
我正在嘗試創建一個離線手寫識別系統。 由於我是初學者,我決定嘗試重新創建Harald Scheidl 的中篇文章中描述的 model 。 model 框架如下圖所示。
**我的問題如下:
想象一下,您有一個編碼器-解碼器 model,它輸出 position t 在 [0, T) 范圍內的單詞的概率分布。 有了這個輸出,你可以在 position t 處用概率最高的單詞組成一個句子,這種方法稱為greedy 。 Greedy在分類任務上效果很好,但對於句子生成,output 可能看起來有點奇怪。 另一方面,您可以使用束搜索。 束搜索很容易理解(我在這里鏈接了一個很好的資源來理解它),簡而言之,我們可以說束搜索通過僅乘以概率來計算可能序列的可能性來尋找最可能的 output 序列 S,並選擇最可能的:p(s) = p(0) * p(1) *... * p(T)。
CTC 損失有點特殊,因此有一個定制的 output 也有額外的字符,表明 position t 與以前的相同。 使用ctd_decoder
中的 ctd_decoder,您可以使用貪婪或波束搜索方法對 CTC output 序列進行解碼。
1)您在 model.predict 之后通過 keras 使用 ctc 解碼
out = K.get_value(K.ctc_decode(prediction,
input_length=np.ones(prediction.shape[0])*prediction.shape[1],
greedy=True)[0][0])
2)是的,在定義自定義 ctc 損失后,我們在 model.compile(loss='') 中調用它
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer =
'adam')
在 MODEL 架構中使用 CTC假設您有一個具有以下架構的 model
inputs = Input(shape=(32,128,1))
#any number of conv and rnn layers
outputs = Dense(len(char_list)+1, activation = 'softmax')(blstm_2)
# model to be used at test time
act_model = Model(inputs, outputs)
現在我們定義 ctc 損失並准備 ctc 輸入所需的標簽,我們將 ctc 損失稱為 keras api
labels = Input(name='the_labels', shape=[max_label_len], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
def ctc_lambda_func(args):
y_pred, labels, input_length, label_length = args
return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')
([outputs, labels, input_length, label_length])
#model to be used at training time
model = Model(inputs=[inputs, labels, input_length, label_length],
outputs=loss_out)
現在我們編譯它
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer =
'adam')
使用 CTC 解碼器在測試集上測試經過訓練的 MODEL
# load the saved best model weights
act_model.load_weights('best_model30epoch.hdf5')
# predict outputs on validation images
prediction = act_model.predict(test_img[:15])
# use CTC decoder
out = K.get_value(K.ctc_decode(prediction,
input_length=np.ones(prediction.shape[0])*prediction.shape[1],
greedy=True)[0][0])
# see the results
i = 0
for x in out:
img = mpimg.imread(file_img[i])
imgplot = plt.imshow(img)
plt.show()
print("original_text = ", test_orig_txt[i])
print("predicted text = ", end = '')
for p in x:
if int(p) != -1:
print(char_list[int(p)], end = '')
print('\n')
i+=1
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.