简体   繁体   English

如何将使用 Keras model 训练的 Tensorflow 2.* 转换为.onnx 格式?

[英]How to convert Tensorflow 2.* trained with Keras model to .onnx format?

I use the Python 3.7.4 with TensorFlow 2.0 and Keras 2.2.4-tf to train my own CNN model. I use the Python 3.7.4 with TensorFlow 2.0 and Keras 2.2.4-tf to train my own CNN model. Everything goes fine.一切顺利。 I can use eg model.save(my_model), and then use it in other Python scripts.我可以使用例如 model.save(my_model),然后在其他 Python 脚本中使用它。 Problem appears when I want to use trained model in OpenCV with its DNN module in C++.当我想在 OpenCV 中使用经过训练的 model 及其在 C++ 中的 DNN 模块时,会出现问题。 cv::dnn:readNetFromTensorflow(model.pb, model.pbtxt), takes as you can see two arguments, and I can't get the second.pbtxt file. cv::dnn:readNetFromTensorflow(model.pb, model.pbtxt),你可以看到两个arguments,我不能得到第二个.pbtxt文件。 So I decide to use.onnx format, because of its flexibility.所以我决定使用.onnx 格式,因为它的灵活性。 The problem is that existing libraries keras2onnx takes only model from TensorFlow 1.*, and I want to avoid working with it.问题是现有库 keras2onnx 仅从 TensorFlow 1.* 获取 model,我想避免使用它。 Example of code to convert it is presented below:转换它的代码示例如下所示:

import tensorflow as tf
import onnx
import keras2onnx
model = tf.keras.models.load_model(my_model_folder_path)
onnx_model = keras2onnx.convert_keras(model, model.name)
onnx.save_model(onnx_model, model_name_onnx)

Is there some other ways to convert such model to onnx format?还有其他方法可以将这种 model 转换为 onnx 格式吗?

The latest version of keras2onnx (in github master) supports TensorFlow 2.最新版本的keras2onnx(在github master)支持TensorFlow 2。

You can install it like this:你可以像这样安装它:

pip install git+https://github.com/microsoft/onnxconverter-common
pip install git+https://github.com/onnx/keras-onnx

You need to create a file which can hold ONNX object.您需要创建一个可以容纳 ONNX object 的文件。 Visit https://github.com/onnx/tutorials/blob/master/tutorials/OnnxTensorflowExport.ipynb访问https://github.com/onnx/tutorials/blob/master/tutorials/OnnxTensorflowExport.ipynb

import tensorflow as tf
import onnx
import keras2onnx
model = tf.keras.models.load_model('Model.h5')
onnx_model = keras2onnx.convert_keras(model, model.name)

file = open("Sample_model.onnx", "wb")
file.write(onnx_model.SerializeToString())
file.close()

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM