
How to convert HuggingFace's Seq2seq models to ONNX format

I am trying to convert the Pegasus newsroom model from HuggingFace's transformers to the ONNX format. I followed this guide published by HuggingFace. After installing the prerequisites, I ran this code:

!rm -rf onnx/
from pathlib import Path
from transformers.convert_graph_to_onnx import convert

convert(framework="pt", model="google/pegasus-newsroom", output=Path("onnx/google/pegasus-newsroom.onnx"), opset=11)

and got these errors:

ValueError                                Traceback (most recent call last)
<ipython-input-9-3b37ed1ceda5> in <module>()
      3 from transformers.convert_graph_to_onnx import convert
      4 
----> 5 convert(framework="pt", model="google/pegasus-newsroom", output=Path("onnx/google/pegasus-newsroom.onnx"), opset=11)
      6 
      7 

6 frames
/usr/local/lib/python3.6/dist-packages/transformers/models/pegasus/modeling_pegasus.py in forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask, encoder_head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    938             input_shape = inputs_embeds.size()[:-1]
    939         else:
--> 940             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
    941 
    942         # past_key_values_length

ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds

I have never seen this error before. Any ideas?

Pegasus is a seq2seq model; you can't directly convert a seq2seq (encoder-decoder) model using this method. The guide is for BERT, which is an encoder-only model. Only encoder-only or decoder-only transformer models can be converted using this method.
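For comparison, the same convert() helper runs without error on an encoder-only model. A minimal sketch, mirroring the question's own call but with a BERT checkpoint (the output path is just an illustrative choice):

from pathlib import Path
from transformers.convert_graph_to_onnx import convert

# BERT is encoder-only, so convert() can build a dummy input for the whole model
convert(framework="pt", model="bert-base-cased",
        output=Path("onnx/bert-base-cased.onnx"), opset=11)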

To convert a seq2seq (encoder-decoder) model, you have to split it and convert each part separately: the encoder to ONNX and the decoder to ONNX. You can follow this guide (it was done for T5, which is also a seq2seq model).
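A minimal sketch of the split-and-export idea, assuming a Pegasus checkpoint; the wrapper classes, dummy shapes, and output file names here are illustrative choices, not an official API:

import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

model_name = "google/pegasus-newsroom"
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).eval()

class EncoderWrapper(torch.nn.Module):
    """Return a plain tensor so torch.onnx.export can trace the graph."""
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
    def forward(self, input_ids, attention_mask):
        return self.encoder(input_ids=input_ids,
                            attention_mask=attention_mask).last_hidden_state

class DecoderWrapper(torch.nn.Module):
    """The decoder gets its own dummy inputs, including the encoder output."""
    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
    def forward(self, decoder_input_ids, encoder_hidden_states):
        return self.decoder(input_ids=decoder_input_ids,
                            encoder_hidden_states=encoder_hidden_states).last_hidden_state

enc = tokenizer("A dummy input sentence.", return_tensors="pt")

# 1) Export the encoder with its dummy inputs.
torch.onnx.export(
    EncoderWrapper(model.get_encoder()),
    (enc["input_ids"], enc["attention_mask"]),
    "pegasus_encoder.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["encoder_hidden_states"],
    dynamic_axes={"input_ids": {0: "batch", 1: "seq"},
                  "attention_mask": {0: "batch", 1: "seq"},
                  "encoder_hidden_states": {0: "batch", 1: "seq"}},
    opset_version=11,
)

# 2) Export the decoder with *its* dummy inputs -- this is the step the
#    generic convert() helper never performs, hence the missing
#    decoder_input_ids in the question's error.
encoder_hidden_states = EncoderWrapper(model.get_encoder())(enc["input_ids"], enc["attention_mask"])
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
torch.onnx.export(
    DecoderWrapper(model.get_decoder()),
    (decoder_input_ids, encoder_hidden_states),
    "pegasus_decoder.onnx",
    input_names=["decoder_input_ids", "encoder_hidden_states"],
    output_names=["decoder_hidden_states"],
    dynamic_axes={"decoder_input_ids": {0: "batch", 1: "target_seq"},
                  "encoder_hidden_states": {0: "batch", 1: "seq"}},
    opset_version=11,
)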

Why are you getting this error?

While converting a PyTorch model to ONNX,

_ = torch.onnx._export(
    model,
    dummy_input,
    ...
)

you need to provide a dummy input to both the encoder and the decoder separately. By default, this conversion method provides a dummy input only to the encoder. Since this conversion method doesn't handle the decoder of a seq2seq model, it never gives the decoder a dummy input, and you get the above error: ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds
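To see the missing piece concretely, here is a small illustration (assuming the transformers version from the question, where Pegasus does not build decoder inputs automatically): calling the full model with only encoder inputs reproduces the error, while also passing decoder_input_ids satisfies the decoder.

import torch
from transformers import PegasusForConditionalGeneration

model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-newsroom")
input_ids = torch.ones((1, 8), dtype=torch.long)  # dummy token ids

# model(input_ids=input_ids)  # raises: You have to specify either decoder_input_ids or decoder_inputs_embeds
out = model(input_ids=input_ids, decoder_input_ids=input_ids)  # decoder got its dummy input
print(out.logits.shape)  # (batch, target_seq, vocab_size)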

The ONNX export of canonical models from the Transformers library is supported out of the box in the Optimum library (pip install optimum):

optimum-cli export onnx --model t5-small --task seq2seq-lm-with-past --for-ort t5_small_onnx/

Which will give:

.
└── t5_small_onnx
    ├── config.json
    ├── decoder_model.onnx
    ├── decoder_with_past_model.onnx
    ├── encoder_model.onnx
    ├── special_tokens_map.json
    ├── spiece.model
    ├── tokenizer_config.json
    └── tokenizer.json

You can check optimum-cli export onnx --help for more details. What's cool is that the model can then be used directly with ONNX Runtime, e.g. through ORTModelForSeq2SeqLM.
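For example, a short sketch loading the folder produced by the export command above:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5_small_onnx")
model = ORTModelForSeq2SeqLM.from_pretrained("t5_small_onnx")

inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")
outputs = model.generate(**inputs)  # runs the exported ONNX encoder/decoder under the hood
print(tokenizer.decode(outputs[0], skip_special_tokens=True))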

Pegasus itself is not yet supported, but will soon be: https://github.com/huggingface/optimum/pull/620

Disclaimer: I am a contributor to this lib.
