
Token indices sequence length is longer than the specified maximum sequence length for this model (28627 > 512)

I am using the Hugging Face DistilBERT model as the backend for a question-and-answer application. The text I am using with the model is one very large single text field. Even though it is a single string, the punctuation was left in place as a clue for BERT. When I execute the application I get the "Token indices sequence length" error shown above. I am using the tokenizer.encode_plus() method to pass the text into the model, and I have tried various mechanisms to truncate the input ids to a length <= 512. I am currently developing on Windows 10, but I will also be porting the code to a Raspberry Pi 4 platform.

The code is failing at this line:

start_scores, end_scores = model(torch.tensor([input_ids]), attention_mask=torch.tensor([attention_mask]))

I am attempting to perform the truncation at this line:

encoding = tokenizer.encode_plus(question, tokenizer(context, truncation=True).input_ids)

The entire code is here:

from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
import torch

# globals - set once used everywhere
tokenizer = None
model = None
context = ''

def establishSettings():
    global tokenizer, model, context
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', return_token_type_ids=True, model_max_length=512)

    model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad', return_dict=False)
    # context = "Some 1,500 volcanoes are still considered potentially active around the world today 161 of those over 10 percent sit within the boundaries of the United States."

    # get the volcano corpus
    with open('volcanic.corpus', encoding="utf8") as file:
        context = file.read().replace('\n', '')

    print(len(tokenizer(context, truncation=True).input_ids))


def askQuestion(question):
    global tokenizer, model, context
    print("\nQuestion ", question)
    encoding = tokenizer.encode_plus(question, tokenizer(context, truncation=True).input_ids)
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
    start_scores, end_scores = model(torch.tensor([input_ids]), attention_mask=torch.tensor([attention_mask]))
    ans_tokens = input_ids[torch.argmax(start_scores): torch.argmax(end_scores) + 1]
    answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
    #all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    return answer_tokens


def main():
    # set the global items once
    establishSettings()

    # ask a question
    question = "How many potentially active volcanoes are there in the world today?"
    answer_tokens = askQuestion(question)
    print("answer_tokens: ", answer_tokens)

    if len(answer_tokens) == 0:
        answer = "Sorry, I don't have an answer for that  one.  Ask me another question about New Mexico volcanoes."
        print(answer)
    else:
        answer_tokens_to_string = tokenizer.convert_tokens_to_string(answer_tokens)
        print("\nFinal Answer : ")
        print(answer_tokens_to_string)

if __name__ == '__main__':
    main()

What is the best way to truncate the input_ids to a length <= 512?

Edit this line:

encoding = tokenizer.encode_plus(question, tokenizer(context, truncation=True).input_ids)

to

encoding = tokenizer.encode_plus(question, context, truncation=True, max_length=512)

Passing a nested tokenizer call as the second argument does not cap the total length: the question's tokens are added on top of the (up to) 512 context tokens, and the inner call's special tokens end up embedded in the middle of the sequence. Passing the raw context string and letting encode_plus truncate the combined pair keeps the total at or below 512.
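If you also want to guarantee that the question itself is never cut off, you can ask the tokenizer to truncate only the context half of the pair. Here is a minimal sketch of askQuestion along those lines (it assumes the same tokenizer, model, and context globals as in your code; truncation="only_second" and max_length are standard encode_plus arguments, and return_dict=False means the model returns a plain (start_scores, end_scores) tuple, as in your original code):

def askQuestion(question):
    global tokenizer, model, context
    # Tokenize the pair together; truncate only the context (the second
    # sequence) so the full question always fits within 512 tokens.
    encoding = tokenizer.encode_plus(question, context, truncation="only_second", max_length=512)
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
    # With return_dict=False the model returns a tuple of logits.
    start_scores, end_scores = model(torch.tensor([input_ids]), attention_mask=torch.tensor([attention_mask]))
    # Take the span between the highest-scoring start and end positions.
    ans_tokens = input_ids[torch.argmax(start_scores): torch.argmax(end_scores) + 1]
    return tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)

Keep in mind that truncation simply discards everything past the first 512 tokens, so answers that live deeper in your 28,627-token corpus become unreachable. The longer-term fix is to split the context into overlapping windows (the tokenizer supports this via return_overflowing_tokens=True with a stride) and run the model over each chunk, keeping the best-scoring span.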
