How to do avg-pooling on the output of a BERT model for each sentence?

For classification, we usually use [CLS] to predict labels. But now I have another requirement: to do avg-pooling on the output of each sentence in the BERT model, which seems a little hard for me. Sentences are split by [SEP], but the length of each sentence in each sample of a batch is not equal, so tf.split does not fit this problem.

An example follows (batch_size=2). How do I get the avg-pooling of each sentence?

[CLS] w1 w2 w3 [SEP] w4 w5 [SEP]

[CLS] x1 x2 [SEP] x3 x4 x5 [SEP]
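
To make the obstacle concrete, here is a minimal sketch (toy tensors and invented names, not from the original post) showing that sentence 1 covers three tokens in the first sample but only two in the second, so no single set of tf.split sizes works for the whole batch:

import tensorflow as tf

# Toy stand-ins for the two samples above (hidden size 4 for readability).
sample1 = tf.random.normal([8, 4])  # [CLS] w1 w2 w3 [SEP] w4 w5 [SEP]
sample2 = tf.random.normal([8, 4])  # [CLS] x1 x2 [SEP] x3 x4 x5 [SEP]

# Sentence 1 occupies positions 1..3 in sample1 but 1..2 in sample2, so the
# slice boundaries differ per sample and a fixed split cannot cover both:
sent1_avg_a = tf.reduce_mean(sample1[1:4], axis=0)  # mean of w1, w2, w3
sent1_avg_b = tf.reduce_mean(sample2[1:3], axis=0)  # mean of x1, x2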

You can get the averages by masking.

If you call encode_plus on the tokenizer and set return_token_type_ids to True, you will get a dictionary that contains (see the example after this list):

  • 'input_ids': token indices that you pass into your model
  • 'token_type_ids': a list of 0s and 1s that says which token belongs to which input sentence.
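
For example, a minimal sketch (the checkpoint name and the toy sentence pair are illustrative choices):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
enc = tokenizer.encode_plus("the cat sat", "on the mat",
                            return_token_type_ids=True)
print(enc["input_ids"])       # ids for: [CLS] the cat sat [SEP] on the mat [SEP]
print(enc["token_type_ids"])  # [0, 0, 0, 0, 0, 1, 1, 1, 1]

Note that the first sentence's 0s include [CLS] and the first [SEP], and the second sentence's 1s include the final [SEP].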

Assuming you have batched the token_type_ids into a tensor mask of shape batch × length, such that 0s mark the first sentence, 1s mark the second sentence, and padding is something else (like -1), and you have the BERT output in a tensor output of shape batch × length × 768, you can do:

import tensorflow as tf

# mask:   batch × length, with 0 = first sentence, 1 = second sentence, -1 = padding
# output: batch × length × 768 BERT hidden states
first_sent_mask = tf.cast(mask == 0, tf.float32)
first_sent_lens = tf.reduce_sum(first_sent_mask, axis=1, keepdims=True)
first_sent_mean = (
    tf.reduce_sum(output * tf.expand_dims(first_sent_mask, 2), axis=1) /
    first_sent_lens)
second_sent_mask = tf.cast(mask == 1, tf.float32)
...
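
For completeness, here is a self-contained sketch of the whole pipeline under the same assumptions. It uses Hugging Face transformers with the TensorFlow backend; the checkpoint name, the toy sentences, and the sentence_mean helper are illustrative choices, not part of the original answer:

import tensorflow as tf
from transformers import BertTokenizer, TFBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertModel.from_pretrained("bert-base-uncased")

first_parts = ["the cat sat", "dogs bark"]
second_parts = ["on the mat", "at night quite loudly"]
enc = tokenizer(first_parts, second_parts, padding=True,
                return_token_type_ids=True, return_tensors="tf")

# Padded positions get token_type_id 0 by default; overwrite them with -1 so
# that padding matches neither sentence 0 nor sentence 1.
mask = tf.where(enc["attention_mask"] == 1, enc["token_type_ids"], -1)

output = model(enc["input_ids"],
               attention_mask=enc["attention_mask"],
               token_type_ids=enc["token_type_ids"]).last_hidden_state

def sentence_mean(output, mask, sent_id):
    # Zero out positions outside the requested sentence, sum over the
    # length axis, and divide by that sentence's token count.
    m = tf.cast(mask == sent_id, tf.float32)                       # batch × length
    lens = tf.reduce_sum(m, axis=1, keepdims=True)                 # batch × 1
    summed = tf.reduce_sum(output * tf.expand_dims(m, 2), axis=1)  # batch × 768
    return summed / lens

first_sent_mean = sentence_mean(output, mask, 0)   # avg over sentence-0 tokens
second_sent_mean = sentence_mean(output, mask, 1)  # avg over sentence-1 tokens

Because everything stays vectorized over the batch, masking sidesteps the unequal-length problem that makes tf.split unsuitable here.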
