
Transformers pretrained model with dropout setting

I'm trying to use the Hugging Face Transformers pretrained model bert-base-uncased, but I want to increase dropout. The documentation for the from_pretrained method doesn't mention this, but Colab ran the object instantiation below without any problem. I saw these dropout parameters in the documentation for the transformers.BertConfig class.

Am I using bert-base-uncased AND changing dropout in the correct way?

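from transformers import BertForSequenceClassification

# Keyword arguments that match configuration attributes are used to update the config
model = BertForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path='bert-base-uncased',
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False,
    attention_probs_dropout_prob=0.5,
    hidden_dropout_prob=0.5,
)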

As Elidor00 already said, your assumption is correct. I would also suggest using a separate config object, because it is easier to export and less error-prone. Someone in the comments also asked how to use it via from_pretrained:

from transformers import BertModel, AutoConfig

configuration = AutoConfig.from_pretrained('bert-base-uncased')
configuration.hidden_dropout_prob = 0.5
configuration.attention_probs_dropout_prob = 0.5
        
bert_model = BertModel.from_pretrained(
    pretrained_model_name_or_path='bert-base-uncased',
    config=configuration,
)

Yes, this is correct, but note that there are two dropout parameters and that you are using a specific BERT model, namely BertForSequenceClassification.

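Also, as the documentation suggests, you could first define the configuration and then build the model from it, as shown here. Note that a model constructed directly from a configuration starts with randomly initialized weights rather than the pretrained ones: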

from transformers import BertModel, BertConfig

# Initializing a BERT bert-base-uncased style configuration
configuration = BertConfig()

# Initializing a model from the bert-base-uncased style configuration
model = BertModel(configuration)

# Accessing the model configuration
configuration = model.config
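
To check that the dropout values actually reach the model, one option is to inspect the nn.Dropout layers after construction. The following is only a minimal sketch: the tiny layer sizes are hypothetical values chosen so the model builds quickly without downloading anything, and the weights are random rather than the pretrained bert-base-uncased weights. The same inspection should apply to a model loaded with from_pretrained.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

# Hypothetical, deliberately tiny sizes so the model builds quickly and offline.
# The weights are random here, not the pretrained bert-base-uncased weights.
config = BertConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=2,
    hidden_dropout_prob=0.5,
    attention_probs_dropout_prob=0.5,
)
model = BertForSequenceClassification(config)

# Collect the drop probability of every nn.Dropout layer in the model.
dropout_probs = {
    name: module.p
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Dropout)
}

for name, p in dropout_probs.items():
    print(f"{name}: p={p}")
```

Keep in mind that dropout is only active in training mode. from_pretrained returns the model in evaluation mode, so call model.train() before fine-tuning and model.eval() before inference.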
