Tensorflow-addons seq2seq - start and end tokens in BaseDecoder or BasicDecoder
I am writing code inspired by https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/BasicDecoder.
For translation/generation, we instantiate a BasicDecoder
decoder_instance = tfa.seq2seq.BasicDecoder(
    cell=decoder.rnn_cell, sampler=greedy_sampler, output_layer=decoder.fc)
and call this decoder with the following arguments
outputs, _, _ = decoder_instance(
    decoder_embedding_matrix, start_tokens=start_tokens,
    end_token=end_token, initial_state=decoder_initial_state)
What should start_tokens and end_token be, and what do they represent? An example in the BaseDecoder
docstring gives:
Example using `tfa.seq2seq.GreedyEmbeddingSampler` for inference:
>>> sampler = tfa.seq2seq.GreedyEmbeddingSampler(embedding_layer)
>>> decoder = tfa.seq2seq.BasicDecoder(
... decoder_cell, sampler, output_layer, maximum_iterations=10)
>>>
>>> initial_state = decoder_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
>>> start_tokens = tf.fill([batch_size], 1)
>>> end_token = 2
>>>
>>> output, state, lengths = decoder(
... None, start_tokens=start_tokens, end_token=end_token, initial_state=initial_state)
>>>
>>> output.sample_id.shape
TensorShape([4, 10])
For a translation task, they are
start_tokens = tf.fill([inference_batch_size], targ_lang.word_index['<start>'])
end_token = targ_lang.word_index['<end>']
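For context, a minimal sketch of what these two values end up being. The vocabulary below is hypothetical (the real ids depend on the fitted tokenizer); the point is only the shapes: start_tokens is one id per batch entry, end_token is a single scalar id.

```python
# Hypothetical vocabulary; real ids come from the fitted tokenizer's word_index.
word_index = {'<start>': 1, '<end>': 2, 'hallo': 3, 'welt': 4}
inference_batch_size = 4

# start_tokens: the id of '<start>', repeated once per sequence in the batch.
# It is the first input the sampler feeds to the decoder cell.
start_tokens = [word_index['<start>']] * inference_batch_size  # [1, 1, 1, 1]

# end_token: a single scalar id; decoding of a sequence stops once it is sampled.
end_token = word_index['<end>']  # 2
```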
In my application, the input character string has the following form
next_char = tf.constant(['Glücklicherweise '])
input_chars = tf.strings.unicode_split(next_char, 'UTF-8')
input_ids = ids_from_chars(input_chars).to_tensor()
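For completeness, ids_from_chars above can be a `tf.keras.layers.StringLookup` layer, as in the TF text-generation tutorial; the vocabulary used here is an assumption, built just from the example string:

```python
import tensorflow as tf

# Assumed vocabulary: the unique characters of the example string.
# In a real model it would come from the training corpus.
vocab = sorted(set('Glücklicherweise '))
ids_from_chars = tf.keras.layers.StringLookup(vocabulary=vocab, mask_token=None)

next_char = tf.constant(['Glücklicherweise '])
input_chars = tf.strings.unicode_split(next_char, 'UTF-8')  # 17 characters
input_ids = ids_from_chars(input_chars).to_tensor()
# input_ids is a dense int tensor of shape (1, 17): one id per character.
```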
The model is trained to generate the next token. The generator should produce "lücklicherweise x",
where x stands for the next character found to be most likely (or by a more elaborate search).
I can think of several approaches; they should be easy to follow.
Input:
### Method 1: As string
index = 0
# Slice off the first `index` characters; input_word holds the input string tensor.
next_char = tf.strings.substr(
    input_word, index, len(input_word[0].numpy()) - index, unit="UTF8_CHAR")
end_token = len(input_word[0].numpy())
print('next_char[0].numpy(): ' + str(next_char[0].numpy()))

# pointer and my_index are integer tensors tracking the current position.
def f1():
    global pointer
    print(input_word[pointer])
    pointer = tf.add(pointer, 1)

def f2():
    global my_index
    print('add 1')
    my_index = tf.add(my_index, 1)

r = tf.cond(tf.less_equal(my_index, pointer), f1, f2)
Output:
### Method 1: As string
input_word[0].numpy() length: tf.Tensor([b'Gl\xc3\xbccklicherweise '], shape=(1,), dtype=string)
input_word[0].numpy() length: 18
next_char[0].numpy(): b'Gl\xc3\xbccklicherweise '
next_char[0].numpy(): b'l\xc3\xbccklicherweise '
next_char[0].numpy(): b'\xc3\xbccklicherweise '
next_char[0].numpy(): b'cklicherweise '
next_char[0].numpy(): b'klicherweise '
next_char[0].numpy(): b'licherweise '
...
### Method 2: As alphabets
tf.Tensor([b'G'], shape=(1,), dtype=string)
tf.Tensor([b'l'], shape=(1,), dtype=string)
tf.Tensor([b'\xc3\xbc'], shape=(1,), dtype=string)
tf.Tensor([b'c'], shape=(1,), dtype=string)
tf.Tensor([b'k'], shape=(1,), dtype=string)
tf.Tensor([b'l'], shape=(1,), dtype=string)
tf.Tensor([b'i'], shape=(1,), dtype=string)
tf.Tensor([b'c'], shape=(1,), dtype=string)
tf.Tensor([b'h'], shape=(1,), dtype=string)
tf.Tensor([b'e'], shape=(1,), dtype=string)
tf.Tensor([b'r'], shape=(1,), dtype=string)
tf.Tensor([b'w'], shape=(1,), dtype=string)
tf.Tensor([b'e'], shape=(1,), dtype=string)
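The code that produced the Method 2 output was not shown above, so the following is only a possible reconstruction: split the input into Unicode characters and emit each one as a shape-(1,) string tensor.

```python
import tensorflow as tf

next_char = tf.constant(['Glücklicherweise '])
chars = tf.strings.unicode_split(next_char, 'UTF-8')  # ragged, 17 characters
for c in chars[0]:
    # Each step is a shape-(1,) string tensor,
    # e.g. tf.Tensor([b'G'], shape=(1,), dtype=string)
    step = tf.expand_dims(c, 0)
    print(step)
```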