[英]How to make custom Op in TensorFlow importable in Python?
我已經為自定義Op實現了一個內核,並將其作為custom_op.cc
放入/tensorflow/core/user_ops
中。 在Op內部,我完成所有注冊工作,例如REGISTER_OP
和REGISTER_KERNEL_BUILDER
。
然后,我在Python中為該Op實現了漸變,然后將其放在與custom_op_grad.py
相同的文件夾中。 我也在這里進行了所有注冊( @ops.RegisterGradient
)。
我創建了BUILD文件,內容如下:
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
tf_custom_op_library(
name = "custom_op.so",
srcs = ["custom_op.cc"],
)
py_library(
name = "custom_op_grad",
srcs = ["custom_op_grad.py"],
srcs_version = "PY2",
deps = [
":custom_op_grad",
"//tensorflow:tensorflow_py",
],
)
之后,我重建Tensorflow:
pip uninstall tensorflow
bazel clean
bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
cp -r bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/__main__/* bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
pip install /tmp/tensorflow_pkg/tensorflow-0.8.0-py2-none-any.whl
當我嘗試使用所有這些操作時,通過調用tf.user_ops.custom_op
可以告訴我該模塊沒有該操作。
也許我還需要執行其他一些步驟? 或者我對BUILD
文件做錯了什么?
好的,我找到了解決方案。 我剛剛刪除了BUILD
文件,並且我的自定義Op已成功構建並且可以使用tensorflow.user_ops.custom_op()
在Python中導入。
要使用漸變,我必須將其代碼直接放在tensorflow/python/user_ops/user_ops.py
。 不是最優雅的解決方案,但現在可以使用。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.