
How to use tf.keras.backend.ctc_decode in a sequential model?

I am trying to create an offline handwriting recognition system. Since I am a beginner, I decided to try to recreate the model described in a Medium article by Harald Scheidl. The model framework is shown in the following image.

[Image: model architecture from the article]
**My questions are as follows:**

  1. How do I use CTC decode from Keras? I am using a sequential model with Keras layers.
  2. Will CTC loss be the loss function in the compile parameter?

Imagine that you have an encoder-decoder model that outputs a probability distribution over the word at position t, for t in range [0, T). With these outputs, you can compose a sentence by taking the word with the highest probability at each position t; this approach is called greedy decoding. Greedy decoding works well on classification tasks, but for sentence generation the output may look a bit weird. On the other hand, you can use beam search. Beam search is pretty simple to understand (I link here a good resource to understand it): in a few words, beam search looks for the most probable output sequence S by computing the likelihood of candidate sequences, multiplying their per-step probabilities, and selecting the most probable one: p(S) = p(s_0) * p(s_1) * ... * p(s_{T-1}).
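To make the difference concrete, here is a minimal sketch of both strategies over a (T, V) matrix of per-step probabilities; the example matrix and the beam width are illustrative values, not from the answer:

import numpy as np

def greedy_decode(probs):
    # pick the most probable symbol independently at each time step
    return [int(np.argmax(p)) for p in probs]

def beam_search_decode(probs, beam_width=3):
    # keep the beam_width most probable prefixes at each time step
    beams = [([], 1.0)]  # (sequence so far, its probability)
    for p in probs:
        candidates = [(seq + [v], score * p[v])
                      for seq, score in beams
                      for v in range(len(p))]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return beams[0][0]

# With per-step probabilities that do not depend on the prefix the two
# coincide; they differ when the model conditions on previous outputs
# (or, for CTC, when several paths collapse to the same label sequence).
probs = np.array([[0.6, 0.4], [0.5, 0.5], [0.1, 0.9]])
print(greedy_decode(probs))       # [0, 0, 1]
print(beam_search_decode(probs))  # [0, 0, 1]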

The CTC loss is a bit special and consequently has a custom output: it includes an extra blank character that marks "no new symbol", which is how the model indicates that position t still belongs to the same character as the previous time steps. With ctc_decode from the Keras API you can decode the CTC output sequence using either the greedy or the beam search approach.
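To see what decoding does with that extra character, here is a minimal sketch of the collapse rule that ctc_decode applies internally (merge repeats, then drop the blank); the alphabet and blank index are illustrative assumptions:

def ctc_collapse(path, blank):
    # merge consecutive repeats, then drop the blank symbol
    out = []
    prev = None
    for p in path:
        if p != prev and p != blank:
            out.append(p)
        prev = p
    return out

alphabet = ['h', 'e', 'l', 'o']          # blank = index 4
path = [0, 0, 1, 1, 2, 4, 2, 2, 3, 3]    # a "hheel-lloo"-style path
print(''.join(alphabet[i] for i in ctc_collapse(path, blank=4)))  # hello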

1) You use CTC decode through Keras after model.predict:

import numpy as np
from tensorflow.keras import backend as K

out = K.get_value(K.ctc_decode(prediction,
                               input_length=np.ones(prediction.shape[0]) * prediction.shape[1],
                               greedy=True)[0][0])
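For reference, K.ctc_decode returns a tuple of (decoded sequences, log probabilities), where the first element is a list with one dense tensor per requested path, padded with -1 where a sequence is shorter than the longest one; the [0][0] indexing above selects the first (and only) path. An equivalent, more explicit form:

decoded, log_probs = K.ctc_decode(prediction,
                                  input_length=np.ones(prediction.shape[0]) * prediction.shape[1],
                                  greedy=True)
out = K.get_value(decoded[0])  # dense (batch, max_len) array, padded with -1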

2) Yes. After defining the custom CTC loss we pass it to model.compile(loss=...). Because the loss is already computed inside the model by a Lambda layer (shown below), the compiled loss function simply returns y_pred and ignores y_true:

model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam')

USING CTC IN THE MODEL ARCHITECTURE

Suppose you have a model with the following architecture:

inputs = Input(shape=(32, 128, 1))
# any number of conv and rnn layers, ending in blstm_2
outputs = Dense(len(char_list) + 1, activation='softmax')(blstm_2)

# model to be used at test time
act_model = Model(inputs, outputs)
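For context, the blstm_2 tensor above comes from the elided convolutional and recurrent layers. A minimal sketch of one possible body, where the layer sizes, pooling choices, and the example alphabet are my own illustrative assumptions rather than the answer's actual architecture:

from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Permute,
                                     Reshape, Bidirectional, LSTM, Dense)

char_list = list('abcdefghijklmnopqrstuvwxyz')  # illustrative alphabet

inputs = Input(shape=(32, 128, 1))
x = Conv2D(64, (3, 3), padding='same', activation='relu')(inputs)
x = MaxPooling2D(pool_size=(2, 2))(x)   # -> (16, 64, 64)
x = Conv2D(128, (3, 3), padding='same', activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 1))(x)   # -> (8, 64, 128), width preserved
x = Permute((2, 1, 3))(x)               # width becomes the time axis
x = Reshape((64, 8 * 128))(x)           # (time_steps=64, features)
blstm_1 = Bidirectional(LSTM(128, return_sequences=True))(x)
blstm_2 = Bidirectional(LSTM(128, return_sequences=True))(blstm_1)
outputs = Dense(len(char_list) + 1, activation='softmax')(blstm_2)  # +1 for the CTC blank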

Now we define the CTC loss and prepare the extra inputs it requires (labels and lengths); we call the CTC loss through the Keras API:

labels = Input(name='the_labels', shape=[max_label_len], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')


def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    # K.ctc_batch_cost computes the CTC loss for each sample in the batch
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)


loss_out = Lambda(ctc_lambda_func, output_shape=(1,),
                  name='ctc')([outputs, labels, input_length, label_length])

# model to be used at training time
model = Model(inputs=[inputs, labels, input_length, label_length],
              outputs=loss_out)

Now we compile it:

model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam')
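Training then feeds the real labels and lengths through the extra input layers, while the y passed to fit is a dummy array that the lambda loss ignores. A minimal sketch, where train_img, train_padded_txt, train_input_length, and train_label_length are hypothetical arrays prepared to match the four inputs:

import numpy as np

model.fit(x=[train_img, train_padded_txt, train_input_length, train_label_length],
          # the lambda loss returns y_pred (the CTC loss tensor) and ignores
          # y_true, so zeros of the right batch dimension are sufficient
          y=np.zeros(len(train_img)),
          batch_size=64,   # illustrative
          epochs=30)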

USING CTC DECODER TO TEST TRAINED MODEL ON TEST SET

# load the saved best model weights
act_model.load_weights('best_model30epoch.hdf5')

# predict outputs on the first 15 test images
prediction = act_model.predict(test_img[:15])

# use CTC decoder
out = K.get_value(K.ctc_decode(prediction,
                               input_length=np.ones(prediction.shape[0]) * prediction.shape[1],
                               greedy=True)[0][0])

# see the results
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

i = 0
for x in out:
    img = mpimg.imread(file_img[i])
    imgplot = plt.imshow(img)
    plt.show()
    print("original_text =  ", test_orig_txt[i])
    print("predicted text = ", end='')
    for p in x:
        # ctc_decode pads shorter sequences with -1, so skip those entries
        if int(p) != -1:
            print(char_list[int(p)], end='')
    print('\n')
    i += 1
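If greedy decoding gives poor transcriptions, the same call can run the beam search discussed above by setting greedy=False; the beam width here is an illustrative value:

out = K.get_value(K.ctc_decode(prediction,
                               input_length=np.ones(prediction.shape[0]) * prediction.shape[1],
                               greedy=False,
                               beam_width=50,
                               top_paths=1)[0][0])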
