[英]Adding new op in tensorflow - Shape functions
I'm trying to add a new operation in Tensorflow where I have two inputs, namely a 3D tensor and a constant, which outputs a 4D tensor. 我正在尝试在Tensorflow中添加一个新操作,其中有两个输入,即3D张量和一个常数,输出4D张量。 The 4D tensor is obtained by replicating the 3D tensor a number of times defined by the constant.
通过将3D张量复制常数定义的次数,可以得到4D张量。 The shape function is implemented in the following way:
shape函数通过以下方式实现:
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c)
{
::tensorflow::shape_inference::ShapeHandle output;
::tensorflow::shape_inference::ShapeHandle out1 = c->Vector(::tensorflow::shape_inference::DimensionOrConstant(5));
TF_RETURN_IF_ERROR(c->Concatenate(c->input(0),out1,&output));
c->set_output(0,output);
return Status::OK();
})
.Doc(R"doc(
Replicating the 3D input tensor in a 4D tensor.
)doc");
I would like that the size of the fourth dimension (defined by out1 in the code) is set to the second input (namely the constant value). 我想将第四维的大小(由代码中的out1定义)设置为第二个输入(即常量值)。 How to do it?
怎么做?
Perhaps MakeShapeFromShapeTensor
is what you're looking for? 也许
MakeShapeFromShapeTensor
是您要找的东西? Something like: 就像是:
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c)
{
::tensorflow::shape_inference::ShapeHandle n;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &n));
::tensorflow::shape_inference::ShapeHandle out;
TF_RETURN_IF_ERROR(c->Concatenate(n, c->input(0), &out));
c->set_output(0, out);
return Status::OK();
})
That said, you probably know this, but just to be sure: Element-wise arithmetic operations in TensorFlow support broadcasting , so at least in those case you shouldn't need this custom op. 也就是说,您可能知道这一点,但是请确保: TensorFlow中的逐元素算术运算支持broadcast ,因此至少在这种情况下,您不需要此自定义操作。
For other cases, you could also combine tf.tile
, tf.shape
, tf.concat
and tf.reshape
to achieve the same effect. 对于其他情况,您还可以结合使用
tf.tile
, tf.shape
, tf.concat
和tf.reshape
来达到相同的效果。 For example, the following creates a matrix by repeating a vector: 例如,以下通过重复向量创建矩阵:
import tensorflow as tf
oneD = tf.constant([1,2])
n = tf.constant([5])
twoD = tf.reshape(tf.tile(oneD, n), tf.concat([n, tf.shape(oneD)], 0))
with tf.Session() as sess:
print oneD.eval()
print twoD.eval()
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.