Register Custom Operator in TensorFlow Lite C API

I am running TensorFlow Lite on Android using the C API. My model requires the operator RandomStandardNormal, which was recently implemented as a custom op prototype in TensorFlow v2.4.0-rc0 here.

The TfLiteInterpreterOptionsAddCustomOp() function is declared in tensorflow/lite/c/c_api_experimental.h:

TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp(
    TfLiteInterpreterOptions* options, const char* name,
    const TfLiteRegistration* registration, int32_t min_version,
    int32_t max_version);
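For context on the arguments: `registration` is a table of kernel callbacks, and `name` together with the `[min_version, max_version]` range says which custom ops in the model that table serves. The following is a simplified C++ model of that bookkeeping, not TF Lite's implementation: `MiniRegistration` and `MiniOpResolver` are hypothetical stand-ins, and the real `TfLiteRegistration` (in `tensorflow/lite/c/common.h`) has more fields with different callback signatures.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>

// Hypothetical stand-in for TfLiteRegistration: a table of kernel callbacks.
// The real struct has more callbacks, and they take context/node arguments.
struct MiniRegistration {
  int (*prepare)();
  int (*invoke)();
  int32_t version;  // set per registered version by AddCustom
};

// Hypothetical resolver: maps (custom op name, version) -> registration.
class MiniOpResolver {
 public:
  // Store one copy of *reg for every version in [min_version, max_version].
  void AddCustom(const std::string& name, const MiniRegistration* reg,
                 int32_t min_version, int32_t max_version) {
    assert(min_version <= max_version);
    for (int32_t v = min_version; v <= max_version; ++v) {
      MiniRegistration copy = *reg;
      copy.version = v;
      ops_[{name, v}] = copy;
    }
  }

  // Look up the kernel a model node asks for; nullptr if nothing matches.
  const MiniRegistration* FindOp(const std::string& name,
                                 int32_t version) const {
    auto it = ops_.find({name, version});
    return it == ops_.end() ? nullptr : &it->second;
  }

 private:
  std::map<std::pair<std::string, int32_t>, MiniRegistration> ops_;
};
```

In this model, the question's call with `1, 1` makes only version 1 of `RandomStandardNormal` resolvable; a model node asking for any other version would find nothing.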

Looking at this example and thread, I am trying to use TfLiteInterpreterOptionsAddCustomOp like this:

#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/c_api_experimental.h"

// create model and interpreter options
TfLiteModel *model = TfLiteModelCreateFromFile("path/to/model.tflite");
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();

// register custom ops
TfLiteInterpreterOptionsAddCustomOp(options, "RandomStandardNormal", Register_RANDOM_STANDARD_NORMAL(), 1, 1);

// create the interpreter
TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
TfLiteInterpreterAllocateTensors(interpreter);

I see that the Register_RANDOM_STANDARD_NORMAL() function is declared in the tflite::ops::custom C++ namespace in tensorflow/lite/kernels/custom_ops_register.h. But when I try to include that header in my C file, the compiler complains, because namespace is not a C keyword (it reports it as an unknown type).

How can I register a custom operator using the TensorFlow Lite C API? Since this custom operator is defined in C++, do I need to use a C++ compiler in order to use the C API with it?

NOTE: I include //tensorflow/lite/kernels:custom_ops in the Bazel BUILD deps when compiling libtensorflowlite_c.so.

It looks like this was answered on GitHub with this workaround:

https://github.com/tensorflow/tensorflow/issues/44664#issuecomment-723310060

On the TensorFlow GitHub issue, @jdduke suggested a temporary workaround:

  • add an extern "C" wrapper declaration to custom_ops_register.h
extern "C" {
TFL_CAPI_EXPORT TfLiteRegistration* TfLiteRegisterRandomStandardNormal();
}
  • add an extern "C" wrapper implementation to random_standard_normal.cc
extern "C" {
TFL_CAPI_EXPORT TfLiteRegistration* TfLiteRegisterRandomStandardNormal() {
  return tflite::ops::custom::Register_RANDOM_STANDARD_NORMAL();
}
}
  • ensure //tensorflow/lite/kernels:custom_ops is included as a dependency of the tensorflowlite_c target in tensorflow/lite/c/BUILD
tflite_cc_shared_object(
    name = "tensorflowlite_c",
    linkopts = select({
        "//tensorflow:ios": [
            "-Wl,-exported_symbols_list,$(location //tensorflow/lite/c:exported_symbols.lds)",
        ],
        "//tensorflow:macos": [
            "-Wl,-exported_symbols_list,$(location //tensorflow/lite/c:exported_symbols.lds)",
        ],
        "//tensorflow:windows": [],
        "//conditions:default": [
            "-z defs",
            "-Wl,--version-script,$(location //tensorflow/lite/c:version_script.lds)",
        ],
    }),
    per_os_targets = True,
    deps = [
        ":c_api",
        ":c_api_experimental",
        ":exported_symbols.lds",
        ":version_script.lds",
        "//tensorflow/lite/kernels:custom_ops", # here
    ],
)
  • modify my C++ code to call this new wrapper function
TfLiteInterpreterOptionsAddCustomOp(options, "RandomStandardNormal", TfLiteRegisterRandomStandardNormal(), 1, 1);

And it worked! My tensors finally allocated on Android :)
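One caveat for callers that stay in C: custom_ops_register.h still declares `tflite::ops::custom`, so a `.c` file that includes it hits the same `namespace` error from the question, even with the new extern "C" block. A common pattern is a small header that both C and C++ compilers can parse, using `#ifdef __cplusplus` guards, which the C file includes instead. The sketch below uses hypothetical names (`Registration`, `kernels::Register_RANDOM_NORMAL`, `ShimRegisterRandomNormal`) and a stand-in struct rather than the real `TfLiteRegistration`; the header and implementation parts are shown in one block only so that it compiles on its own.

```cpp
// ---- shim header, e.g. random_normal_shim.h (hypothetical name) ----
// Contains no namespaces, so a C compiler can parse it; under C++ the
// declaration gets C linkage, matching the definition further down.
#ifndef RANDOM_NORMAL_SHIM_H_
#define RANDOM_NORMAL_SHIM_H_

#ifdef __cplusplus
extern "C" {
#endif

struct Registration;  // incomplete here: C callers only pass the pointer on
struct Registration* ShimRegisterRandomNormal(void);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // RANDOM_NORMAL_SHIM_H_

// ---- shim implementation, compiled as C++ (e.g. random_normal_shim.cc) ----
// Hypothetical stand-in for the real registration struct.
struct Registration {
  int version;
};

namespace kernels {
// C++-only factory, analogous to Register_RANDOM_STANDARD_NORMAL().
Registration* Register_RANDOM_NORMAL() {
  static Registration reg{1};  // function-local static: lives for the whole program
  return &reg;
}
}  // namespace kernels

// Definition with C linkage (unmangled symbol), matching the header.
extern "C" Registration* ShimRegisterRandomNormal(void) {
  return kernels::Register_RANDOM_NORMAL();
}
```

From a `.c` file you would include only the shim header and pass the returned pointer to `TfLiteInterpreterOptionsAddCustomOp`; in a real version the header would include the C header `tensorflow/lite/c/common.h` and return `TfLiteRegistration*` instead of the stand-in type. Only the `.cc` part needs a C++ compiler, and in the workaround above that part is built into libtensorflowlite_c.so, so the application itself can stay C. The `TfLite` prefix on the wrapper's name is probably also relevant: the default (non-Apple, non-Windows) build links with version_script.lds, which, as far as I can tell, exports only symbols matching a `TfLite` pattern.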
