繁体   English   中英

Tensorflow-addons seq2seq - BaseDecoder 或 BasicDecoder 中的开始和结束标记

[英]Tensorflow-addons seq2seq - start and end tokens in BaseDecoder or BasicDecoder

我正在编写受https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/BasicDecoder启发的代码。

在翻译/生成中,我们实例化一个BasicDecoder

  decoder_instance = tfa.seq2seq.BasicDecoder(cell=decoder.rnn_cell, \
         sampler=greedy_sampler, output_layer=decoder.fc)

并使用以下参数调用此解码器

outputs, _, _ = decoder_instance(decoder_embedding_matrix, \
      start_tokens = start_tokens, end_token= end_token, initial_state=decoder_initial_state)

start_tokens 和 end_token 应该是什么,它们代表什么? BaseDecoder签名中的一个示例给出:

 Example using `tfa.seq2seq.GreedyEmbeddingSampler` for inference:
 >>> sampler = tfa.seq2seq.GreedyEmbeddingSampler(embedding_layer)
 >>> decoder = tfa.seq2seq.BasicDecoder(
    ...     decoder_cell, sampler, output_layer, maximum_iterations=10)
 >>>
 >>> initial_state = decoder_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
 >>> start_tokens = tf.fill([batch_size], 1)
 >>> end_token = 2
 >>>
 >>> output, state, lengths = decoder(
    ...     None, start_tokens=start_tokens, end_token=end_token, initial_state=initial_state)
 >>>
 >>> output.sample_id.shape
 TensorShape([4, 10])

对于翻译任务,他们是

start_tokens = tf.fill([inference_batch_size], targ_lang.word_index['<start>'])
end_token = targ_lang.word_index['<end>']

在我的应用程序中,输入的字符链具有以下形式

next_char = tf.constant(['Glücklicherweise '])
input_chars = tf.strings.unicode_split(next_char, 'UTF-8')
input_ids = ids_from_chars(input_chars).to_tensor()

该模型经过训练以生成下一个令牌。 生成器应生成"lücklicherweise x" ,其中 x 代表最可能(或更详细的搜索)的下一个字符。

我在想有多种方法可以很容易理解!

输入:

### Method 1: As string
index = 0
next_char = tf.strings.substr(
    input_word, index, len(input_word[0].numpy()) - index, unit="UTF8_CHAR", name=None
)
end_token = len(input_word[0].numpy())
print('next_char[0].numpy(): ' + str(next_char[0].numpy()))

def f1(): 
    global pointer
    print(input_word[pointer])
    pointer = tf.add(pointer, 1)
    return
def f2(): 
    global my_index
    print('add 1')
    my_index = tf.add(my_index, 1)
    return

r = tf.cond( tf.less_equal(my_index, pointer), f1, f2 )

输出:

### Method 1: As string
input_word[0].numpy() length: tf.Tensor([b'Gl\xc3\xbccklicherweise '], shape=(1,), dtype=string)
input_word[0].numpy() length: 18
next_char[0].numpy(): b'Gl\xc3\xbccklicherweise '
next_char[0].numpy(): b'l\xc3\xbccklicherweise '
next_char[0].numpy(): b'\xc3\xbccklicherweise '
next_char[0].numpy(): b'cklicherweise '
next_char[0].numpy(): b'klicherweise '
next_char[0].numpy(): b'licherweise '
...

### Method 2: As alphabets
tf.Tensor([b'G'], shape=(1,), dtype=string)
tf.Tensor([b'l'], shape=(1,), dtype=string)
tf.Tensor([b'\xc3\xbc'], shape=(1,), dtype=string)
tf.Tensor([b'c'], shape=(1,), dtype=string)
tf.Tensor([b'k'], shape=(1,), dtype=string)
tf.Tensor([b'l'], shape=(1,), dtype=string)
tf.Tensor([b'i'], shape=(1,), dtype=string)
tf.Tensor([b'c'], shape=(1,), dtype=string)
tf.Tensor([b'h'], shape=(1,), dtype=string)
tf.Tensor([b'e'], shape=(1,), dtype=string)
tf.Tensor([b'r'], shape=(1,), dtype=string)
tf.Tensor([b'w'], shape=(1,), dtype=string)
tf.Tensor([b'e'], shape=(1,), dtype=string)

... 编码器-解码器

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM