How to use tf.keras.backend.ctc_decode in a sequential model?

I am trying to create an offline handwriting recognition system. Since I am a beginner, I decided to try to recreate the model described in a Medium article by Harald Scheidl. The model framework is shown in the following image.


**My questions are as follows:**

  1. How do I use CTC decode from Keras? I am using a sequential model with Keras layers.
  2. Will CTC loss be the loss function in the compile parameter?

Imagine that you have an encoder-decoder model that outputs, for each position t in the range [0, T), a probability distribution over words. From these outputs you can compose a sentence by taking the word with the highest probability at each position t; this approach is called greedy. Greedy works well on classification tasks, but for sentence generation the output may look a bit weird. On the other hand, you can use beam search. Beam search is pretty simple to understand: in a few words, it looks for the most probable output sequence S by keeping several candidate sequences, scoring each one by multiplying the probabilities of its elements, and selecting the most probable: p(S) = p(0) * p(1) * ... * p(T-1).
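To make the difference concrete, here is a minimal pure-Python sketch of both strategies on a toy probability table. The names `greedy_decode`, `beam_search` and `toy_probs` are made up for illustration and are not part of Keras. Note that with a plain table of independent per-step probabilities the best beam is the same as the greedy result; beam search additionally gives you the runner-up sequences.

```python
def greedy_decode(probs):
    """Pick the most probable symbol index at every time step."""
    return [max(range(len(step)), key=step.__getitem__) for step in probs]


def beam_search(probs, beam_width=2):
    """Keep the beam_width most probable prefixes at each step.

    Returns (sequence, probability) pairs, most probable first.
    """
    beams = [([], 1.0)]
    for step in probs:
        # Extend every kept prefix with every symbol and score by the product.
        candidates = [
            (seq + [idx], score * p)
            for seq, score in beams
            for idx, p in enumerate(step)
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return beams


# Hypothetical toy output: T = 3 time steps, vocabulary of 3 symbols.
toy_probs = [
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
    [0.1, 0.2, 0.7],
]
greedy_result = greedy_decode(toy_probs)
beam_result = beam_search(toy_probs, beam_width=2)
print(greedy_result)
print(beam_result)
```

A wider beam costs more computation but keeps more alternatives alive, which matters once the score of a sequence is not just the product of independent per-step maxima.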

The CTC loss is a bit special: its output has one extra character, the blank, and consecutive repeated predictions are merged into a single character when decoding. With ctc_decode from the Keras API you can decode the CTC output sequence using either a greedy or a beam search approach.
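As a rough illustration of the rule (merge consecutive repeats, then drop blanks), here is a small pure-Python sketch. The function name `ctc_collapse` and the three-letter alphabet are hypothetical, and I am assuming the blank is the last class index.

```python
def ctc_collapse(path, blank):
    """Merge consecutive repeated indices, then drop the blank index."""
    decoded = []
    previous = None
    for idx in path:
        if idx != previous and idx != blank:
            decoded.append(idx)
        previous = idx
    return decoded


char_list = "abc"        # hypothetical alphabet
blank = len(char_list)   # assumption: blank is the last class index
# Per-time-step best indices: a a _ a b b _ _ c
path = [0, 0, 3, 0, 1, 1, 3, 3, 2]
text = "".join(char_list[i] for i in ctc_collapse(path, blank))
print(text)
```

The blank between the second and third time steps is what allows a genuine double letter ("aa") to survive the merge.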

1) You use CTC decode through Keras after model.predict:

out = K.get_value(K.ctc_decode(prediction,
                               input_length=np.ones(prediction.shape[0]) * prediction.shape[1],
                               greedy=True)[0][0])
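If the arguments look cryptic: `prediction` has shape (batch, time_steps, classes), and `input_length` just says how many time steps of each sample are valid. The following pure-Python sketch mimics, under my assumptions, what the greedy mode gives you back for a batch; `greedy_ctc_decode` is a made-up helper and not the Keras implementation, and the right-padding with -1 is inferred from the check against -1 used when printing results.

```python
def greedy_ctc_decode(prediction, input_length, blank):
    """Rough picture of greedy CTC decoding for a batch of nested lists.

    prediction: shaped (batch, time_steps, classes)
    input_length: number of valid time steps per sample
    Returns rows of class indices, right-padded with -1.
    """
    rows = []
    for sample, length in zip(prediction, input_length):
        # Best class per valid time step.
        best_path = [max(range(len(step)), key=step.__getitem__)
                     for step in sample[:length]]
        # Merge repeats and drop blanks.
        row, previous = [], None
        for idx in best_path:
            if idx != previous and idx != blank:
                row.append(idx)
            previous = idx
        rows.append(row)
    width = max(len(row) for row in rows)
    return [row + [-1] * (width - len(row)) for row in rows]


# Hypothetical batch: 2 samples, 4 time steps, 2 characters + blank (index 2).
prediction = [
    [[0.8, 0.1, 0.1], [0.7, 0.2, 0.1], [0.1, 0.1, 0.8], [0.2, 0.7, 0.1]],
    [[0.1, 0.1, 0.8], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
]
# Same idea as np.ones(batch) * time_steps: every sample uses all time steps.
input_length = [len(prediction[0])] * len(prediction)
decoded = greedy_ctc_decode(prediction, input_length, blank=2)
print(decoded)
```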

2) Yes. After defining a custom CTC loss, we pass it to model.compile(loss=...):

model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam')

USING CTC IN MODEL ARCHITECTURE

Suppose you have a model with the following architecture:

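from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Dense, Lambda
from tensorflow.keras.models import Model

inputs = Input(shape=(32, 128, 1))
# any number of conv and RNN layers go here; blstm_2 is the output of the last one
# one output class per character, plus one for the CTC blank
outputs = Dense(len(char_list) + 1, activation='softmax')(blstm_2)

# model to be used at test time
act_model = Model(inputs, outputs)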

Now we define the CTC loss and prepare the labels required as CTC inputs; we call the CTC loss through the Keras API:

labels = Input(name='the_labels', shape=[max_label_len], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')


def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)


# wrap the loss in a Lambda layer so that the training model outputs the CTC loss
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
    [outputs, labels, input_length, label_length])

# model to be used at training time
model = Model(inputs=[inputs, labels, input_length, label_length],
              outputs=loss_out)
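The training model therefore needs three extra arrays next to the images. Here is a minimal sketch of how they could be built as plain Python lists; `encode_label`, `build_ctc_inputs`, the alphabet and `time_steps=31` are all hypothetical, and I am assuming that padding positions beyond `label_length` are ignored by the loss. The `dummy_targets` list stands in for the `y` argument of `fit`, which is not actually used when the model output is already the loss.

```python
def encode_label(text, char_list):
    """Map each character to its index in char_list."""
    return [char_list.index(ch) for ch in text]


def build_ctc_inputs(texts, char_list, max_label_len, time_steps):
    """Build padded labels, input lengths, label lengths and dummy targets."""
    pad_value = len(char_list)  # assumption: padding is ignored by the loss
    padded_labels, input_length, label_length = [], [], []
    for text in texts:
        encoded = encode_label(text, char_list)
        if len(encoded) > max_label_len:
            raise ValueError(f"label {text!r} is longer than max_label_len")
        padded_labels.append(encoded + [pad_value] * (max_label_len - len(encoded)))
        label_length.append([len(encoded)])   # shape [1] per sample
        input_length.append([time_steps])     # time steps produced by the network
    dummy_targets = [0.0] * len(texts)
    return padded_labels, input_length, label_length, dummy_targets


# Hypothetical values for illustration only.
char_list = "abc"
padded_labels, input_length, label_length, dummy_targets = build_ctc_inputs(
    ["ab", "c"], char_list, max_label_len=3, time_steps=31)
print(padded_labels, input_length, label_length, dummy_targets)
```

In practice you would convert these lists to NumPy arrays before passing them to the model, with `time_steps` set to the actual number of time steps your network outputs.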

Now we compile it:

# the 'ctc' layer already outputs the loss, so the loss function just returns y_pred
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam')

USING CTC DECODER TO TEST TRAINED MODEL ON TEST SET

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# load the saved best model weights
act_model.load_weights('best_model30epoch.hdf5')

# predict outputs on validation images
prediction = act_model.predict(test_img[:15])

# use CTC decoder; every sample uses all of its time steps
out = K.get_value(K.ctc_decode(prediction,
                               input_length=np.ones(prediction.shape[0]) * prediction.shape[1],
                               greedy=True)[0][0])

# see the results
for i, x in enumerate(out):
    img = mpimg.imread(file_img[i])
    plt.imshow(img)
    plt.show()
    print("original_text =  ", test_orig_txt[i])
    print("predicted text = ", end='')
    for p in x:
        # -1 marks padding in the decoded output
        if int(p) != -1:
            print(char_list[int(p)], end='')
    print('\n')
