[英]What Python types does TensorFlow accept for Attr's of type “tensor”?
I am defining a new Op in C++ which takes in a single attribute of type tensor
, roughly following these instructions . 我在C ++中定义了一个新的Op,它接受了
tensor
类型的单个属性,大致遵循这些指令 。 A stripped version of the Op code is below: 以下是Op代码的剥离版本:
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
REGISTER_OP("DoStuff")
.Attr("attr: tensor = { dtype: DT_FLOAT }")
.Input("in: float")
.Output("out: float");
class DoStuffOp : public OpKernel {
public:
explicit DoStuffOp(OpKernelConstruction *context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("attr", &attr_));
// ...
}
void Compute(OpKernelContext *context) override {
// ...
}
private:
Tensor attr_;
};
REGISTER_KERNEL_BUILDER(Name("DoStuff").Device(DEVICE_CPU), DoStuffOp);
I can compile the Op into a .so
file fine. 我可以将Op编译成
.so
文件。 However, I can't figure out how to successfully pass in a value for attr
. 但是,我无法弄清楚如何成功传入
attr
的值。 When I run the following in Python: 当我在Python中运行以下内容时:
import tensorflow as tf
dostufflib = tf.load_op_library('build/do_stuff.so')
sess = tf.InteractiveSession()
A = [[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]]
X = tf.Variable(tf.constant(1.0))
Y = dostufflib.do_stuff(X, A)
I get TypeError: Don't know how to convert [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] to a TensorProto for argument 'attr'
. 我得到
TypeError: Don't know how to convert [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] to a TensorProto for argument 'attr'
。 Nothing I do seems to satisfy the type conversion: list
, numpy
array, tf.Tensor
, tf.Variable
, etc. How do you pass Python variables into an Op as tensor attributes? 我没做什么似乎满足类型转换:
list
, numpy
数组, tf.Tensor
, tf.Variable
等。你如何将Python变量作为张量属性传递给Op?
After much more hunting, I found tf.contrib.util.make_tensor_proto
, a function that converts a python scalar, python list, numpy ndarray, or numpy scalar into a tf.TensorProto
object. 经过更多的搜索,我找到了
tf.contrib.util.make_tensor_proto
,这是一个将python标量,python列表,numpy ndarray或numpy标量转换为tf.TensorProto
对象的函数。 The following works: 以下作品:
A = tf.contrib.util.make_tensor_proto([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])
X = tf.Variable(tf.constant(1.0))
Y = dostufflib.do_stuff(X, A)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.