簡體   English   中英

如何為每個句子在 bert model 的 output 上做 avg pool?

[英]how to do avg pool on the output of bert model for each sentence?

對於分類,我們通常使用 [CLS] 來預測標簽。 但現在我有另一個請求對 bert model 中每個句子的 output 進行平均池化。 對我來說似乎有點難? 句子被 [SEP] 分割,但批次的每個樣本中每個句子的長度不相等,所以 tf.split 不適合這個問題?

一個例子如下(batch_size=2),如何得到每個句子的 avg-pooling?

[CLS] w1 w2 w3 [sep] w4 w5 [sep]

[CLS] x1 x2 [sep] x3 w4 x5 [sep]

您可以通過掩碼獲得平均值。

如果您在標記器上調用encode_plus並將return_token_type_ids設置為True ,您將獲得一個包含以下內容的字典:

  • 'input_ids' :傳遞給 model 的令牌索引
  • 'token_type_ids' :一個 0 和 1 的列表,表示哪個標記屬於哪個輸入句子。

假設您對token_type_ids進行了批處理,其中 0 是第一句,1 是第二句,填充是形狀為batch × length的變量mask中的張量中的其他內容(如-1),並且您在張量中有 BERT output在形狀×長度×768的變量output中,你可以這樣做:

first_sent_mask  = tf.cast(mask == 0, tf.float32)
first_sent_lens = tf.reduce_sum(first_sent_mask, axis=1, keepdims=True)
first_sent_mean = (
    tf.reduce_sum(output * tf.expand_dims(first_sent_mask, 2)) /
    first_sent_lens)
second_sent_mask = tf.cast(mask == 1, tf.float32)
...

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM