
How to get intermediate layers' output of pre-trained BERT model in HuggingFace Transformers library?

(I'm following this PyTorch tutorial about BERT word embeddings, in which the author accesses the intermediate layers of the BERT model.)

What I want is to access the last 4 layers of the BERT model for a single input token, in TensorFlow 2 using HuggingFace's Transformers library. Each layer outputs a vector of length 768, so the last 4 layers together have a shape of 4*768=3072 (for each token).

How can I implement this in TF2/Keras to get the intermediate layers of the pretrained model for an input token? (Later I will try to get the vectors for each token in a sentence, but for now one token is enough.)

I'm using HuggingFace's BERT model:

!pip install transformers
from transformers import (TFBertModel, BertTokenizer)

bert_model = TFBertModel.from_pretrained("bert-base-uncased")  # Automatically loads the config
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
sentence_marked = "hello"
tokenized_text = bert_tokenizer.tokenize(sentence_marked)
indexed_tokens = bert_tokenizer.convert_tokens_to_ids(tokenized_text)

print(indexed_tokens)
>> prints [7592]

The output is a single token id ([7592]), which should be the input for the BERT model.
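
TFBertModel expects a 2-D batch of token ids, so a minimal sketch for wrapping this single token (using plain TensorFlow on top of the code above) would be:

import tensorflow as tf

# TFBertModel expects input_ids of shape (batch_size, sequence_length),
# so wrap the single-token id list in a batch dimension.
input_ids = tf.constant([indexed_tokens])  # shape: (1, 1)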

The third element of the BERT model's output is a tuple that consists of the output of the embedding layer as well as the hidden states of the intermediate layers. From the documentation:

hidden_states (tuple(tf.Tensor), optional, returned when config.output_hidden_states=True): tuple of tf.Tensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

Hidden-states of the model at the output of each layer plus the initial embedding outputs.

For the bert-base-uncased model, config.output_hidden_states is True by default. Therefore, to access the hidden states of the 12 intermediate layers, you can do the following:

# input_ids (and optionally attention_mask, which can be omitted when there
# is no padding) have shape (batch_size, sequence_length)
outputs = bert_model(input_ids)
hidden_states = outputs[2][1:]  # skip the embedding output, keep the 12 layer outputs

There are 12 elements in the hidden_states tuple, corresponding to all the layers from the first to the last, and each of them is a tensor of shape (batch_size, sequence_length, hidden_size). So, for example, to access the hidden state of the third layer for the fifth token of all the samples in the batch, you can do: hidden_states[2][:, 4].
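
To build the 3072-dimensional vector mentioned in the question (the last 4 layers concatenated for one token), a minimal sketch on top of the hidden_states tuple above might be:

import tensorflow as tf

# Each layer output has shape (batch_size, sequence_length, 768).
# Select the first (and only) token from the last 4 layers and
# concatenate along the hidden dimension: 4 * 768 = 3072.
token_vecs = [layer[:, 0] for layer in hidden_states[-4:]]  # each: (batch_size, 768)
token_embedding = tf.concat(token_vecs, axis=-1)            # (batch_size, 3072)
print(token_embedding.shape)                                # (1, 3072)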


Note that if the model you are loading does not return the hidden states by default, then you can load the config using the BertConfig class and pass the output_hidden_states=True argument, like this:

from transformers import BertConfig, TFBertModel

config = BertConfig.from_pretrained("name_or_path_of_model",
                                    output_hidden_states=True)

bert_model = TFBertModel.from_pretrained("name_or_path_of_model",
                                         config=config)


 