
Tensorflow-addons seq2seq - start and end tokens in BaseDecoder or BasicDecoder

I am writing code inspired by https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/BasicDecoder .

In the translation/generation step, we instantiate a BasicDecoder:

    decoder_instance = tfa.seq2seq.BasicDecoder(
        cell=decoder.rnn_cell, sampler=greedy_sampler, output_layer=decoder.fc)

and call this decoder with the following args:

    outputs, _, _ = decoder_instance(
        decoder_embedding_matrix, start_tokens=start_tokens,
        end_token=end_token, initial_state=decoder_initial_state)
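For reference, in the linked tutorial these objects come from a trained decoder model, roughly as follows (paraphrased from the tutorial; the sampler is built without an embedding function, which is why the embedding matrix is passed as the first call argument):

    greedy_sampler = tfa.seq2seq.GreedyEmbeddingSampler()
    # The embedding weights are passed explicitly to the decoder call because
    # the sampler was created without an embedding_fn.
    decoder_embedding_matrix = decoder.embedding.variables[0]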

What should start_tokens and end_token be, and what do they represent? An example in BaseDecoder's docstring gives:

Example using `tfa.seq2seq.GreedyEmbeddingSampler` for inference:

    >>> sampler = tfa.seq2seq.GreedyEmbeddingSampler(embedding_layer)
    >>> decoder = tfa.seq2seq.BasicDecoder(
    ...     decoder_cell, sampler, output_layer, maximum_iterations=10)
    >>>
    >>> initial_state = decoder_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    >>> start_tokens = tf.fill([batch_size], 1)
    >>> end_token = 2
    >>>
    >>> output, state, lengths = decoder(
    ...     None, start_tokens=start_tokens, end_token=end_token, initial_state=initial_state)
    >>>
    >>> output.sample_id.shape
    TensorShape([4, 10])
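For context, a self-contained version of that docstring example might look like the following; the layer sizes and the IDs 1 and 2 for the start/end tokens are my own placeholders, not part of the original example:

    import tensorflow as tf
    import tensorflow_addons as tfa

    batch_size, vocab_size, embed_dim, units = 4, 20, 8, 16

    # Toy layers standing in for a trained model (all sizes are assumptions).
    embedding_layer = tf.keras.layers.Embedding(vocab_size, embed_dim)
    decoder_cell = tf.keras.layers.LSTMCell(units)
    output_layer = tf.keras.layers.Dense(vocab_size)

    sampler = tfa.seq2seq.GreedyEmbeddingSampler(embedding_layer)
    decoder = tfa.seq2seq.BasicDecoder(
        decoder_cell, sampler, output_layer, maximum_iterations=10)

    initial_state = decoder_cell.get_initial_state(
        batch_size=batch_size, dtype=tf.float32)

    # start_tokens: one vocabulary ID per batch element; the sampler embeds
    # them and feeds them to the cell as the first decoder input.
    # end_token: a scalar ID; a sequence is finished once the decoder emits it.
    # The IDs 1 and 2 here are arbitrary placeholders.
    start_tokens = tf.fill([batch_size], 1)
    end_token = 2

    output, state, lengths = decoder(
        None, start_tokens=start_tokens, end_token=end_token,
        initial_state=initial_state)
    print(output.sample_id.shape)  # (4, 10), or shorter if end_token is emitted first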

For the translation task, they are:

    start_tokens = tf.fill([inference_batch_size], targ_lang.word_index['<start>'])
    end_token = targ_lang.word_index['<end>']

In my application, the input string of characters has the form:

    next_char = tf.constant(['Glücklicherweise '])
    input_chars = tf.strings.unicode_split(next_char, 'UTF-8')  # RaggedTensor: one row of characters
    input_ids = ids_from_chars(input_chars).to_tensor()         # dense IDs, shape (1, num_chars)
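Here `ids_from_chars` is not shown; presumably it is a `tf.keras.layers.StringLookup` layer built from the training vocabulary, as in the TensorFlow text-generation tutorial. A sketch under that assumption:

    import tensorflow as tf

    # Toy vocabulary; in practice this would be every character seen in training.
    vocab = sorted(set('Glücklicherweise '))
    ids_from_chars = tf.keras.layers.StringLookup(
        vocabulary=list(vocab), mask_token=None)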

The model is trained to generate the next token. The generator should produce "lücklicherweise x", where x stands for the most probable next character (or one found by a more elaborate search).
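As a minimal sketch of the greedy case, assuming a trained `model` that returns logits of shape `(batch, time, vocab_size)` and a `chars_from_ids` inverse lookup (both are my assumptions, not objects from the question):

    logits = model(input_ids)                       # (1, num_chars, vocab_size)
    next_id = tf.argmax(logits[:, -1, :], axis=-1)  # most probable next-character ID
    print(chars_from_ids(next_id))                  # the predicted character x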

I can think of multiple ways to do this; the following one is easy to understand.

Input:

    ### Method 1: As string
    import tensorflow as tf

    input_word = tf.constant(['Glücklicherweise '])
    print('input_word[0].numpy() length: ' + str(input_word))
    # len() of the encoded bytes counts bytes, not characters:
    # 'ü' is two bytes in UTF-8, so this is 18 rather than 17.
    end_token = len(input_word[0].numpy())
    print('input_word[0].numpy() length: ' + str(end_token))

    index = 0
    # Drop `index` leading characters; tf.strings.substr clamps the length,
    # so passing the byte length with unit="UTF8_CHAR" still works.
    # Re-running with index = 1, 2, ... yields the shrinking suffixes below.
    next_char = tf.strings.substr(
        input_word, index, end_token - index, unit="UTF8_CHAR")
    print('next_char[0].numpy(): ' + str(next_char[0].numpy()))

    # Advance a pointer with tf.cond: consume a character while
    # my_index <= pointer, otherwise move the index forward.
    pointer = tf.Variable(0)
    my_index = tf.Variable(0)

    def f1():
        print(input_word[0])
        pointer.assign_add(1)
        return pointer.value()

    def f2():
        print('add 1')
        my_index.assign_add(1)
        return my_index.value()

    r = tf.cond(tf.less_equal(my_index, pointer), f1, f2)

Output:

    ### Method 1: As string
    input_word[0].numpy() length: tf.Tensor([b'Gl\xc3\xbccklicherweise '], shape=(1,), dtype=string)
    input_word[0].numpy() length: 18
    next_char[0].numpy(): b'Gl\xc3\xbccklicherweise '
    next_char[0].numpy(): b'l\xc3\xbccklicherweise '
    next_char[0].numpy(): b'\xc3\xbccklicherweise '
    next_char[0].numpy(): b'cklicherweise '
    next_char[0].numpy(): b'klicherweise '
    next_char[0].numpy(): b'licherweise '
    ...

    ### Method 2: As alphabets
    tf.Tensor([b'G'], shape=(1,), dtype=string)
    tf.Tensor([b'l'], shape=(1,), dtype=string)
    tf.Tensor([b'\xc3\xbc'], shape=(1,), dtype=string)
    tf.Tensor([b'c'], shape=(1,), dtype=string)
    tf.Tensor([b'k'], shape=(1,), dtype=string)
    tf.Tensor([b'l'], shape=(1,), dtype=string)
    tf.Tensor([b'i'], shape=(1,), dtype=string)
    tf.Tensor([b'c'], shape=(1,), dtype=string)
    tf.Tensor([b'h'], shape=(1,), dtype=string)
    tf.Tensor([b'e'], shape=(1,), dtype=string)
    tf.Tensor([b'r'], shape=(1,), dtype=string)
    tf.Tensor([b'w'], shape=(1,), dtype=string)
    tf.Tensor([b'e'], shape=(1,), dtype=string)
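The code that produced Method 2's output is not shown above; a minimal sketch that yields tensors of the same shape would be the following (the variable names are my own):

    word = tf.constant(['Glücklicherweise '])
    chars = tf.strings.unicode_split(word, 'UTF-8').to_tensor()  # shape (1, 17)
    for i in range(chars.shape[1]):
        print(chars[:, i])  # e.g. tf.Tensor([b'G'], shape=(1,), dtype=string)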

