
Tensorflow-addons seq2seq - start and end tokens in BaseDecoder or BasicDecoder

I am writing code inspired by https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/BasicDecoder.

For translation/generation, we instantiate a BasicDecoder

  decoder_instance = tfa.seq2seq.BasicDecoder(
      cell=decoder.rnn_cell, sampler=greedy_sampler, output_layer=decoder.fc)

and call this decoder with the following arguments

outputs, _, _ = decoder_instance(
    decoder_embedding_matrix, start_tokens=start_tokens,
    end_token=end_token, initial_state=decoder_initial_state)

What should start_tokens and end_token be, and what do they represent? An example in the BaseDecoder signature gives:

 Example using `tfa.seq2seq.GreedyEmbeddingSampler` for inference:
 >>> sampler = tfa.seq2seq.GreedyEmbeddingSampler(embedding_layer)
 >>> decoder = tfa.seq2seq.BasicDecoder(
    ...     decoder_cell, sampler, output_layer, maximum_iterations=10)
 >>>
 >>> initial_state = decoder_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
 >>> start_tokens = tf.fill([batch_size], 1)
 >>> end_token = 2
 >>>
 >>> output, state, lengths = decoder(
    ...     None, start_tokens=start_tokens, end_token=end_token, initial_state=initial_state)
 >>>
 >>> output.sample_id.shape
 TensorShape([4, 10])

For the translation task, they are

start_tokens = tf.fill([inference_batch_size], targ_lang.word_index['<start>'])
end_token = targ_lang.word_index['<end>']
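
To make the roles of the two arguments concrete, here is a minimal sketch in plain Python of a greedy decoding loop. It is not the tfa implementation: `greedy_decode`, `toy_step`, and the 4-token vocabulary are hypothetical, and the toy "model" carries no RNN state. It only illustrates that start_tokens holds one id per batch element, fed as the first decoder input, while end_token is a single id that marks a sequence as finished once it is sampled.

```python
def greedy_decode(step_fn, start_tokens, end_token, maximum_iterations):
    """Toy greedy decoder: step_fn(token_id) returns scores over the vocabulary."""
    batch_size = len(start_tokens)
    inputs = list(start_tokens)            # first decoder input of each sequence
    finished = [False] * batch_size
    outputs = [[] for _ in range(batch_size)]
    for _ in range(maximum_iterations):
        if all(finished):                  # stop early once every sequence ended
            break
        for b in range(batch_size):
            if finished[b]:
                continue
            scores = step_fn(inputs[b])
            sample_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
            outputs[b].append(sample_id)
            if sample_id == end_token:     # sampling end_token finishes the sequence
                finished[b] = True
            inputs[b] = sample_id          # feed the prediction back as next input
    return outputs


# Hypothetical vocabulary: 0=<pad>, 1=<start>, 2=<end>, 3='a'.
# The toy model always moves 1 -> 3 -> 2.
def toy_step(token_id):
    next_id = {1: 3, 3: 2}.get(token_id, 2)
    return [1.0 if i == next_id else 0.0 for i in range(4)]


start_tokens = [1, 1]   # one start id per batch element, like tf.fill([batch_size], 1)
end_token = 2           # a single scalar id shared by the whole batch
decoded = greedy_decode(toy_step, start_tokens, end_token, maximum_iterations=10)
print(decoded)
```

In this sketch both sequences stop after two steps because the toy model emits the end id, well before maximum_iterations is reached.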

In my application, the input character sequence has the following form

next_char = tf.constant(['Glücklicherweise '])
input_chars = tf.strings.unicode_split(next_char, 'UTF-8')
input_ids = ids_from_chars(input_chars).to_tensor()

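The model is trained to generate the next token. The generator should produce "lücklicherweise x", where x is the most likely next character (or the result of a more elaborate search).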
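
The shifted-by-one relationship between input and output can be sketched in plain Python. This assumes a common next-character training setup that the post does not spell out, and the vocabulary built here is hypothetical; a real model would use the vocabulary behind ids_from_chars.

```python
text = 'Glücklicherweise '

# Hypothetical vocabulary built from this one string only.
vocab = sorted(set(text))
char_to_id = {ch: i for i, ch in enumerate(vocab)}
id_to_char = {i: ch for ch, i in char_to_id.items()}

ids = [char_to_id[ch] for ch in text]
input_ids = ids[:-1]     # ids for 'Glücklicherweise'
target_ids = ids[1:]     # ids for 'lücklicherweise ' (input shifted left by one)

target_text = ''.join(id_to_char[i] for i in target_ids)
print(repr(target_text))
```

At generation time the whole string is fed in, and the prediction at the last position plays the role of x.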

I think there are several ways to do this that are easy to follow.

Input:

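### Method 1: As string
import tensorflow as tf

# Taken from the printed output below: a batch holding one string.
input_word = tf.constant(['Glücklicherweise '])

index = 0
# len(...numpy()) is the byte length (18 here, since 'ü' is two UTF-8 bytes);
# substr clips a length that runs past the end of the string.
next_char = tf.strings.substr(
    input_word, index, len(input_word[0].numpy()) - index, unit="UTF8_CHAR"
)
end_token = len(input_word[0].numpy())
print('next_char[0].numpy(): ' + str(next_char[0].numpy()))

# Assumed starting values; the original snippet does not show them.
pointer = tf.constant(0)
my_index = tf.constant(0)

def f1():
    # Print the element at `pointer`, then advance the pointer.
    global pointer
    print(input_word[pointer])
    pointer = tf.add(pointer, 1)

def f2():
    # Otherwise advance the index only.
    global my_index
    print('add 1')
    my_index = tf.add(my_index, 1)

r = tf.cond(tf.less_equal(my_index, pointer), f1, f2)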

Output:

### Method 1: As string
input_word[0].numpy() length: tf.Tensor([b'Gl\xc3\xbccklicherweise '], shape=(1,), dtype=string)
input_word[0].numpy() length: 18
next_char[0].numpy(): b'Gl\xc3\xbccklicherweise '
next_char[0].numpy(): b'l\xc3\xbccklicherweise '
next_char[0].numpy(): b'\xc3\xbccklicherweise '
next_char[0].numpy(): b'cklicherweise '
next_char[0].numpy(): b'klicherweise '
next_char[0].numpy(): b'licherweise '
...

### Method 2: As alphabets
tf.Tensor([b'G'], shape=(1,), dtype=string)
tf.Tensor([b'l'], shape=(1,), dtype=string)
tf.Tensor([b'\xc3\xbc'], shape=(1,), dtype=string)
tf.Tensor([b'c'], shape=(1,), dtype=string)
tf.Tensor([b'k'], shape=(1,), dtype=string)
tf.Tensor([b'l'], shape=(1,), dtype=string)
tf.Tensor([b'i'], shape=(1,), dtype=string)
tf.Tensor([b'c'], shape=(1,), dtype=string)
tf.Tensor([b'h'], shape=(1,), dtype=string)
tf.Tensor([b'e'], shape=(1,), dtype=string)
tf.Tensor([b'r'], shape=(1,), dtype=string)
tf.Tensor([b'w'], shape=(1,), dtype=string)
tf.Tensor([b'e'], shape=(1,), dtype=string)
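
The byte strings in the output above can be reproduced without TensorFlow. The following is a small plain-Python sketch, not the code that produced the output; it shows why the reported length is 18 although the word has 17 characters, and why 'ü' appears as b'\xc3\xbc'.

```python
word = 'Glücklicherweise '

encoded = word.encode('utf-8')
num_chars = len(word)       # Unicode characters
num_bytes = len(encoded)    # UTF-8 bytes; 'ü' takes two
print(num_chars, num_bytes)

# Per-character byte strings, comparable to the "Method 2" lines above.
chars_as_bytes = [ch.encode('utf-8') for ch in word]
print(chars_as_bytes[:3])
```

This is why the substring calls in Method 1 pass unit="UTF8_CHAR": positions then count characters rather than bytes.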

... encoder-decoder
