(I'm following this PyTorch tutorial about BERT word embeddings, and in the tutorial the author accesses the intermediate layers of the BERT model.)
What I want is to access the last (let's say) 4 layers for a single input token of the BERT model in TensorFlow 2 using HuggingFace's Transformers library. Because each layer outputs a vector of length 768, concatenating the last 4 layers gives a vector of length 4*768=3072 (for each token).
How can I implement this in TF/Keras/TF2, to get the intermediate layers of a pretrained model for an input token? (Later I will try to get the vectors for each token in a sentence, but for now one token is enough.)
I'm using HuggingFace's BERT model:
!pip install transformers
from transformers import (TFBertModel, BertTokenizer)
bert_model = TFBertModel.from_pretrained("bert-base-uncased") # Automatically loads the config
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
sentence_marked = "hello"
tokenized_text = bert_tokenizer.tokenize(sentence_marked)
indexed_tokens = bert_tokenizer.convert_tokens_to_ids(tokenized_text)
print(indexed_tokens)
>> prints [7592]
The output is a token id ([7592]), which should be the input for the BERT model.
The third element of the BERT model's output is a tuple that consists of the output of the embedding layer as well as the hidden states of the intermediate layers. From the documentation:
hidden_states (tuple(tf.Tensor), optional, returned when config.output_hidden_states=True): tuple of tf.Tensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the initial embedding outputs.
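For the bert-base-uncased model, this answer assumes config.output_hidden_states is True (the library's general default for this flag is False; if outputs has no third element, enable the flag as described in the note at the end). With hidden states returned, you can access those of the 12 encoder layers as follows: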
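outputs = bert_model(input_ids, attention_mask=attention_mask)
# outputs[2] holds the embedding output followed by the 12 layer outputs; [1:] drops the embedding output
hidden_states = outputs[2][1:]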
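The call above leaves input_ids and attention_mask undefined. As a rough sketch of what they would contain for the single token from the question: BERT expects the sequence wrapped in [CLS] and [SEP], whose ids in the bert-base-uncased vocabulary are 101 and 102. The helper name build_single_sentence_inputs is hypothetical and only for illustration; in practice, calling the tokenizer directly (for example bert_tokenizer("hello", return_tensors="tf")) should give equivalent tensors.

```python
# Special-token ids in the bert-base-uncased vocabulary.
CLS_ID = 101  # [CLS]
SEP_ID = 102  # [SEP]


def build_single_sentence_inputs(token_ids):
    """Wrap WordPiece ids in [CLS] ... [SEP] and build a matching attention mask.

    Hypothetical helper for illustration. Returns (input_ids, attention_mask)
    as nested lists with a batch dimension of 1.
    """
    ids = [CLS_ID] + list(token_ids) + [SEP_ID]
    mask = [1] * len(ids)  # no padding, so every position is attended to
    return [ids], [mask]


input_ids, attention_mask = build_single_sentence_inputs([7592])  # 7592 is "hello"
print(input_ids)       # [[101, 7592, 102]]
print(attention_mask)  # [[1, 1, 1]]
```

With the special tokens added, "hello" sits at sequence position 1, not 0. The nested lists would still need to be converted to tensors (e.g. with tf.constant) before being passed to the model.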
There are 12 elements in the hidden_states tuple, corresponding to all the layers from the first to the last, and each of them is an array of shape (batch_size, sequence_length, hidden_size). So, for example, to access the hidden state of the third layer for the fifth token of all the samples in the batch, you can do: hidden_states[2][:,4].
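To get the 4*768=3072-length vector the question asks for, the same indexing can be applied to the last four layers and the results joined along the feature axis. The sketch below uses random NumPy arrays as stand-ins for the real hidden states (an assumption, so it runs without loading the model), and concat_last_layers is a hypothetical helper name. With actual TensorFlow tensors, tf.concat(per_layer, axis=-1) would be the equivalent call.

```python
import numpy as np


def concat_last_layers(hidden_states, token_index, n_last=4):
    """Concatenate one token's vectors from the last `n_last` layers.

    hidden_states: sequence of arrays, each of shape
        (batch_size, sequence_length, hidden_size), ordered first layer to last.
    Returns an array of shape (batch_size, n_last * hidden_size).
    """
    # One (batch_size, hidden_size) slice per selected layer.
    per_layer = [layer[:, token_index] for layer in hidden_states[-n_last:]]
    # Join the slices along the feature axis.
    return np.concatenate(per_layer, axis=-1)


# Stand-in data with BERT-base dimensions: 12 layers, hidden size 768.
rng = np.random.default_rng(0)
batch_size, sequence_length, hidden_size = 2, 6, 768
hidden_states = tuple(
    rng.standard_normal((batch_size, sequence_length, hidden_size)).astype(np.float32)
    for _ in range(12)
)

# Third layer, fifth token, all samples in the batch.
third_layer_fifth_token = hidden_states[2][:, 4]
# Last four layers of the fifth token, concatenated.
token_vector = concat_last_layers(hidden_states, token_index=4)

print(third_layer_fifth_token.shape)  # (2, 768)
print(token_vector.shape)             # (2, 3072)
```

The layers are concatenated in order, so the final 768 values of each row come from the last layer.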
Note that if the model you are loading does not return the hidden states by default, then you can load the config using the BertConfig class and pass the output_hidden_states=True argument, like this:
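from transformers import BertConfig, TFBertModel

# Ask the model to return the hidden states of every layer
config = BertConfig.from_pretrained("name_or_path_of_model", output_hidden_states=True)
bert_model = TFBertModel.from_pretrained("name_or_path_of_model", config=config)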