
How to use TensorFlow CTC beam search properly?

I want to perform CTC beam search on matrices of phoneme probability values (the output of an ASR model). TensorFlow has a CTC beam search implementation, but it's poorly documented and I haven't managed to build a working example. I want to write code that uses it as a benchmark.

Here is my code so far:

import numpy as np
import tensorflow as tf

def decode_ctcBeam(matrix, classes):
    matrix = np.reshape(matrix, (matrix.shape[0], 1, matrix.shape[1]))
    aa_ctc_blank_aa_logits = tf.constant(matrix)
    sequence_length = tf.constant(np.array([len(matrix)], dtype=np.int32))

    (decoded_list,), log_probabilities = tf.nn.ctc_beam_search_decoder(
        inputs=aa_ctc_blank_aa_logits,
        sequence_length=sequence_length,
        merge_repeated=True,
        beam_width=25)

    with tf.Session() as sess:
        out = list(sess.run(tf.sparse_tensor_to_dense(decoded_list)[0]))
    print(out)

    return out

if __name__ == '__main__':
    classes = ['AA', 'B', 'CH']
    mat = np.array([[0.4, 0, 0.6, 0.2], [0.4, 0, 0.6, 0.2]], dtype=np.float32)

    actual = decode_ctcBeam(mat, classes)

I'm having trouble understanding the code:

  • In the example, mat has shape (2, 4), but the TensorFlow function needs a (2, 1, 4) shape, so I reshape it with matrix = np.reshape(matrix, (matrix.shape[0], 1, matrix.shape[1])). What does this mean mathematically? Are mat and matrix the same, or am I mixing things up here? As I understand it, the 1 in the middle is the batch size.
  • The decode_ctcBeam function returns a list; in the example it gives [2], which should mean 'CH' from the defined classes. How do I generalize this and find the recognized phoneme sequence if I have a larger input matrix and, say, 40 phonemes?

Looking forward to your answers / comments! Thanks!

First, the TF documentation is wrong: beam search with beam width 1 is NOT the same as greedy decoding (I created an issue about this some time ago).
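To see how the two can differ, here is a minimal pure-Python sketch (not TensorFlow's implementation) of best path (greedy) decoding next to a plain CTC prefix beam search without a language model. It assumes TF's convention that the blank is the last class; the function names and the two-frame probability matrix are made up for illustration.

```python
from collections import defaultdict

def greedy_decode(probs):
    """Best path decoding: argmax per frame, merge repeats, then drop blanks."""
    blank = len(probs[0]) - 1  # assumption: blank is the last class, as in TF
    best = [max(range(len(frame)), key=frame.__getitem__) for frame in probs]
    out, prev = [], None
    for label in best:
        if label != prev and label != blank:
            out.append(label)
        prev = label
    return tuple(out)

def prefix_beam_search(probs, beam_width):
    """Plain CTC prefix beam search (no language model) over per-frame probabilities."""
    blank = len(probs[0]) - 1
    # prefix -> (prob. of paths ending in blank, prob. of paths ending in a label)
    beams = {(): (1.0, 0.0)}
    for frame in probs:
        nxt = defaultdict(lambda: [0.0, 0.0])
        for prefix, (pb, pnb) in beams.items():
            # Emitting a blank keeps the prefix unchanged.
            nxt[prefix][0] += (pb + pnb) * frame[blank]
            for c, p in enumerate(frame):
                if c == blank:
                    continue
                if prefix and prefix[-1] == c:
                    # A repeat without a blank in between collapses into the same prefix...
                    nxt[prefix][1] += pnb * p
                    # ...while a repeat after a blank starts a new, doubled label.
                    nxt[prefix + (c,)][1] += pb * p
                else:
                    nxt[prefix + (c,)][1] += (pb + pnb) * p
        ranked = sorted(nxt.items(), key=lambda kv: sum(kv[1]), reverse=True)
        beams = {prefix: tuple(scores) for prefix, scores in ranked[:beam_width]}
    return max(beams, key=lambda prefix: sum(beams[prefix]))

# Two frames, classes: 0 = 'a', 1 = 'b', 2 = blank.
probs = [[0.4, 0.3, 0.3],
         [0.3, 0.4, 0.3]]
print(greedy_decode(probs))                     # (0, 1)
print(prefix_beam_search(probs, beam_width=1))  # (0,)
```

With beam width 1 the single surviving prefix 'a' collects probability from both the blank and the repeated 'a' in the second frame (0.12 + 0.12 = 0.24), which beats extending it to 'ab' (0.16), so the result differs from the frame-wise argmax path.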

Then, instead of np.reshape, you could simply use np.transpose to reorder the dimensions and then add a batch dimension of size 1 with np.expand_dims.
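Mathematically, the reshape in the question does not change any values: for a time-major (T, C) matrix it only inserts a batch axis of length 1, so it is equivalent to np.expand_dims(matrix, 1), and mat and matrix hold the same numbers. np.transpose matters only when the model output has its axes in a different order. A small NumPy check, assuming the question's toy matrix and a made-up batch-major array for the second case:

```python
import numpy as np

# Toy (T, C) matrix from the question: T = 2 frames, C = 4 classes (last = blank).
mat = np.array([[0.4, 0.0, 0.6, 0.2],
                [0.4, 0.0, 0.6, 0.2]], dtype=np.float32)

# The decoder wants a time-major [max_time, batch_size, num_classes] array.
via_reshape = np.reshape(mat, (mat.shape[0], 1, mat.shape[1]))
via_expand = np.expand_dims(mat, axis=1)  # insert a batch axis of size 1

# Same shape, same numbers: matrix[t, 0, :] is exactly mat[t, :].
print(via_reshape.shape, np.array_equal(via_reshape, via_expand))  # (2, 1, 4) True

# For a batch-major (batch, T, C) model output, transpose instead of reshaping.
batch_major = np.stack([mat, 2 * mat, 3 * mat])    # shape (3, 2, 4)
time_major = np.transpose(batch_major, (1, 0, 2))  # shape (2, 3, 4)
wrong = np.reshape(batch_major, (2, 3, 4))         # same shape, scrambled frames
print(np.array_equal(time_major[:, 1, :], batch_major[1]))  # True
print(np.array_equal(wrong[:, 1, :], batch_major[1]))       # False
```

The last line shows why reshape is only safe for inserting a length-1 axis: for a real batch it reinterprets the memory layout instead of swapping axes.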

Finally, regarding the TF beam search implementation: yes, the documentation is not very good. I used the implementation in a text recognition model; here are the lines relevant to you:

  • Create the TF beam search operation: make sure to set merge_repeated=False, as TF's default setting (True) does not make sense for 99.99999% of all relevant use cases. Just follow the variable names of the passed arguments to see what they look like; e.g. the input matrix is ctcIn3dTBC, which is a transposed version of the RNN output.
  • Transform the output of beam search into a char string: the operation returns a list of sparse tensors, which have to be decoded into the char string.
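Both points can be sketched without TensorFlow. This is a minimal pure-Python sketch, assuming you already have the indices, values and dense_shape arrays of one decoded SparseTensor (e.g. the fields of the SparseTensorValue returned by sess.run); the helper names and the sparse data are hypothetical. merge_consecutive emulates what the TF docs describe for merge_repeated=True (consecutive identical entries in a beam are emitted only once), which is why it also swallows genuinely doubled labels. Grouping the sparse entries by row also avoids tf.sparse_tensor_to_dense, whose default fill value 0 is a real class id (here 'AA') once a batch holds sequences of different lengths.

```python
def merge_consecutive(labels):
    """Emulates merge_repeated=True: keep only the first of consecutive equal labels."""
    out = []
    for label in labels:
        if not out or out[-1] != label:
            out.append(label)
    return out

def sparse_to_sequences(indices, values, dense_shape):
    """Group a 2-D sparse decoder result into one label list per batch entry.

    indices[k] = (batch_row, position) and values[k] is the class id there.
    """
    sequences = [[] for _ in range(dense_shape[0])]
    for (row, _pos), value in sorted(zip(map(tuple, indices), values)):
        sequences[row].append(value)
    return sequences

classes = ['AA', 'B', 'CH']  # the blank (index 3) never appears in decoder output

# Hypothetical sparse result for a batch of two utterances.
indices = [[0, 0], [0, 1], [0, 2], [1, 0]]
values = [1, 1, 2, 0]
dense_shape = [2, 3]

for seq in sparse_to_sequences(indices, values, dense_shape):
    print([classes[i] for i in seq])
# ['B', 'B', 'CH']
# ['AA']

# The extra merge done by merge_repeated=True would drop the legitimate doubled 'B'.
print([classes[i] for i in merge_consecutive([1, 1, 2])])  # ['B', 'CH']
```

The same mapping works unchanged for 40 phonemes; only the classes list grows.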

So, I've made some progress since I asked the question, but I still haven't figured out how to use TensorFlow's CTC beam search properly. It seems that setting top_paths = 1 and beam_width = 1 does give back the expected greedy search output as a list of ints, which can easily be transformed into the required phonemes stored in classes. The output in this case is:

-------Greedy---------
Output int list
[1, 22, 39, 14, 32, 8]
['AE', 'N', ' ', 'G', 'UH', 'D']

With beam search, the results are bad:

-------Beam Search----------
Output int list
[26, 19, 9, 28, 5, 0, 2, 31, 1, 22, 39, 14, 32, 20, 8, 16, 39, 30, 37, 8]
['P', 'K', 'DH', 'S', 'AY', 'AA', 'AH', 'TH', 'AE', 'N', ' ', 'G', 'UH', 'L', 'D', 'IH', ' ', 'T', 'Z', 'D']

The reference is 'I'm good'. The list [1, 22, 39, 14, 32, 8] appears inside the beam search result; are the other parts supposed to be alternative paths? It looks pretty suspicious to me. Does anyone have any ideas?
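One plausible cause, assuming mat holds softmax probabilities (the question describes "matrices of phoneme probability values"): TF documents the decoder inputs as logits and normalizes them with a softmax internally, so passing probabilities effectively applies softmax twice and flattens every frame. The per-frame argmax survives this, so greedy output still looks right, while beam search, which adds up probabilities over many paths, is thrown off. A NumPy sketch of the effect on one made-up frame; feeding np.log of the matrix (clipped to avoid log(0)) is the usual remedy, but treat it as something to verify on your own data.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over the last axis."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# One frame of hypothetical softmax output (4 classes, last = blank).
probs = np.array([0.05, 0.05, 0.1, 0.8], dtype=np.float32)

# Fed in directly, the probabilities are treated as logits and get flattened.
print(softmax(probs).round(3))

# Feeding log-probabilities instead gives back the original distribution.
log_probs = np.log(np.clip(probs, 1e-12, 1.0))  # clip avoids log(0)
print(np.allclose(softmax(log_probs), probs))   # True

# The per-frame argmax is identical in both cases, which is why greedy
# decoding is unaffected while path-summing beam search is not.
same_argmax = int(softmax(probs).argmax()) == int(softmax(log_probs).argmax()) == int(probs.argmax())
print(same_argmax)  # True
```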

import numpy as np
import tensorflow as tf
import Classes

def decode_ctcBeam(matrix, classes):
    matrix = np.reshape(matrix, (matrix.shape[0], 1, matrix.shape[1]))
    aa_ctc_blank_aa_logits = tf.constant(matrix)
    sequence_length = tf.constant(np.array([len(matrix)], dtype=np.int32))

    (decoded_list,), log_probabilities = tf.nn.ctc_beam_search_decoder(
        inputs=aa_ctc_blank_aa_logits,
        sequence_length=sequence_length,
        merge_repeated=True,
        top_paths=1,
        beam_width=4)

    with tf.Session() as sess:
        out = list(sess.run(tf.sparse_tensor_to_dense(decoded_list)[0]))
    print("Output int list")
    print(out)
    seq_list = get_seq_from_list(out, classes)
    return seq_list

def decode_ctcgreedy(matrix, classes):
    matrix = np.reshape(matrix, (matrix.shape[0], 1, matrix.shape[1]))
    aa_ctc_blank_aa_logits = tf.constant(matrix)
    sequence_length = tf.constant(np.array([len(matrix)], dtype=np.int32))

    (decoded_list,), log_probabilities = tf.nn.ctc_beam_search_decoder(
        inputs=aa_ctc_blank_aa_logits,
        sequence_length=sequence_length,
        merge_repeated=True,
        top_paths=1,
        beam_width=1)

    with tf.Session() as sess:
        out = list(sess.run(tf.sparse_tensor_to_dense(decoded_list)[0]))
    print("Output int list")
    print(out)
    seq_list = get_seq_from_list(out, classes)
    return seq_list

def get_seq_from_list(int_list, classes):
    # Map each class index to its phoneme label.
    return [classes[i] for i in int_list]

if __name__ == '__main__':
    mat = np.load('../npy_files/a1003.npy')
    classes = Classes.get_classes()

    print("-------Greedy---------")
    actual = decode_ctcgreedy(mat, classes)
    print(actual)

    print("\n-------Beam Search----------")
    actual = decode_ctcBeam(mat, classes)
    print(actual)
