Adding a new op in TensorFlow - Shape functions

I'm trying to add a new operation to TensorFlow that takes two inputs, a 3D tensor and a constant, and outputs a 4D tensor. The 4D tensor is obtained by replicating the 3D tensor a number of times given by the constant. The shape function is implemented as follows:

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c)
{
    ::tensorflow::shape_inference::ShapeHandle output;
    ::tensorflow::shape_inference::ShapeHandle out1 = c->Vector(::tensorflow::shape_inference::DimensionOrConstant(5));
    TF_RETURN_IF_ERROR(c->Concatenate(c->input(0),out1,&output));
    c->set_output(0,output);
    return Status::OK();
})
.Doc(R"doc(
     Replicating the 3D input tensor in a 4D tensor.
)doc");

I would like the size of the fourth dimension (defined by out1 in the code) to be taken from the second input (the constant value) instead of being hard-coded. How can I do this?

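Perhaps MakeShapeFromShapeTensor is what you're looking for? It builds a shape from the value of an input tensor, so the second input has to be passed as a 1-D integer tensor such as [5] rather than a scalar. Something like: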

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c)
{
    // Interpret the value of input 1 as a shape (here a one-element shape [n]).
    ::tensorflow::shape_inference::ShapeHandle n;
    TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &n));
    // Prepend [n] to the shape of input 0 to get the 4D output shape.
    ::tensorflow::shape_inference::ShapeHandle out;
    TF_RETURN_IF_ERROR(c->Concatenate(n, c->input(0), &out));
    c->set_output(0, out);
    return Status::OK();
})
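The rule this shape function encodes can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the TensorFlow implementation: `infer_output_shape` is a hypothetical helper, shapes are plain lists, and `None` stands in for a dimension whose value is not known when the graph is built (for example, when the second input is not a constant).

```python
def infer_output_shape(input_shape, n):
    """Prepend the replication count n to input_shape.

    Hypothetical helper mirroring Concatenate(n, input(0)).
    n is an int, or None when the count is unknown at graph-construction time.
    """
    if len(input_shape) != 3:
        raise ValueError("expected a 3D input shape, got rank %d" % len(input_shape))
    return [n] + list(input_shape)


print(infer_output_shape([2, 3, 4], 5))     # [5, 2, 3, 4]
print(infer_output_shape([2, 3, 4], None))  # [None, 2, 3, 4]
```

Swapping the operands to `list(input_shape) + [n]` would place the new dimension last instead, which corresponds to the argument order Concatenate(c->input(0), out1, &output) used in the question.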

That said, you probably know this, but just to be sure: element-wise arithmetic operations in TensorFlow support broadcasting, so at least in those cases you shouldn't need this custom op.
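To illustrate why broadcasting can make explicit replication unnecessary, here is a small pure-Python sketch of the NumPy-style broadcasting rule that TensorFlow's element-wise ops follow. `broadcast_shape` is a hypothetical helper written for this example, not a TensorFlow function.

```python
def broadcast_shape(a, b):
    """Return the shape obtained by broadcasting shapes a and b together."""
    # Left-pad the shorter shape with 1s so both have the same rank.
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            out.append(x)   # equal sizes, or b is stretched along this axis
        elif x == 1:
            out.append(y)   # a is stretched along this axis
        else:
            raise ValueError("incompatible dimensions %d and %d" % (x, y))
    return tuple(out)


# A 3D tensor combined with a [5, 1, 1, 1] tensor behaves like 5 stacked copies.
print(broadcast_shape((2, 3, 4), (5, 1, 1, 1)))  # (5, 2, 3, 4)
```

In other words, an element-wise op between the 3D tensor and a tensor of shape [5, 1, 1, 1] already produces a 4D result without materialising the copies first.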

For other cases, you could also combine tf.tile, tf.shape, tf.concat and tf.reshape to achieve the same effect. For example, the following creates a matrix by repeating a vector:

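import tensorflow as tf  # TF 1.x graph-mode API

oneD = tf.constant([1, 2])  # vector to repeat, shape [2]
n = tf.constant([5])        # number of repetitions, as a one-element vector
# Tile into a flat vector of length 5 * 2, then reshape it to [5, 2].
twoD = tf.reshape(tf.tile(oneD, n), tf.concat([n, tf.shape(oneD)], 0))

with tf.Session() as sess:
  print(oneD.eval())
  print(twoD.eval())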
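To see why tiling followed by a reshape gives the repeated matrix, the same two steps can be mimicked on plain Python lists. This is only a sketch of the idea: `tile` and `reshape_2d` are hypothetical stand-ins for tf.tile and tf.reshape, restricted to the 1-D to 2-D case.

```python
def tile(vector, n):
    """Repeat a flat list n times end to end (like tf.tile on a 1-D tensor)."""
    return list(vector) * n


def reshape_2d(flat, rows, cols):
    """Split a flat list into `rows` consecutive chunks of length `cols`."""
    if len(flat) != rows * cols:
        raise ValueError("cannot reshape %d elements into [%d, %d]" % (len(flat), rows, cols))
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


one_d = [1, 2]
n = 5
flat = tile(one_d, n)                    # [1, 2, 1, 2, 1, 2, 1, 2, 1, 2]
two_d = reshape_2d(flat, n, len(one_d))  # target shape is [n] + shape(one_d)
print(two_d)  # [[1, 2], [1, 2], [1, 2], [1, 2], [1, 2]]
```

Because tiling lays the copies out one after another, cutting the flat result into n rows recovers one copy of the original vector per row.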
