
How to mutate a TensorFlow Variable in a custom op?

I'm trying to modify the simple ZeroOut example from the "Adding a New Op" guide so that, instead of creating a new Tensor as its return value, it mutates the input Tensor in place and returns it. I know this is possible because the scatter op does the same thing, but after reading the scatter op source code I still can't work out exactly what to do, given my limited C++ experience.

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;


REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });



#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {

    // Grab the input tensor
    Tensor input_tensor = context->mutable_input(0, true);
    auto input = input_tensor.flat<int32>();

    // We always return the input ref.
    context->forward_ref_input_to_ref_output(0, 0);

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      input(i) = 0;
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

If I compile the above code and run a simple Python script to test it, I get the following error:

Python(14820,0x700003a7e000) malloc: *** error for object 0x7fd5c45a5a88: pointer being freed was not allocated
*** set a breakpoint in malloc_error_break to debug

What do I need to change in my code to achieve this?

I think you should change how you grab the input and how you set the output. According to your REGISTER_OP, the input is not a reference input, so

context->mutable_input(0, true)

should be replaced with

context->input(0)

Also, instead of calling forward_ref_input_to_ref_output, set the output with

context->set_output(0, context->input(0));

I think it will work once the output is set this way.
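Putting both changes together, the kernel might look like the sketch below. It keeps the question's registration unchanged and, like the question's code, assumes a TensorFlow version where `Status::OK()` is still available (newer releases spell it differently). Because copying a `Tensor` is shallow, `input_tensor` shares the input's buffer, so the loop writes into that buffer and the same tensor is then returned as the output.

```cpp
#include <cstdint>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

// Registration as in the question: plain (non-reference) int32 input/output.
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();  // assumption: TF version that still has Status::OK()
    });

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // input(0) is the accessor for a non-reference input. The Tensor copy is
    // shallow, so input_tensor refers to the same underlying buffer.
    Tensor input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Set all but the first element to 0, writing into the shared buffer.
    const int64_t N = input.size();
    for (int64_t i = 1; i < N; ++i) {
      input(i) = 0;
    }

    // Return the same tensor as the output instead of allocating a new one.
    context->set_output(0, context->input(0));
  }
};

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
```

Note that TensorFlow generally treats non-reference inputs as read-only, so whether the write shows up in the Variable that fed the op is not guaranteed. The scatter ops take the other route: they declare the first input and output as reference types (`Ref(T)`), optionally hold `*context->input_ref_mutex(0)` while updating, fetch the tensor with `mutable_input`, and call `forward_ref_input_to_ref_output(0, 0)`. Those are the calls the original kernel made without the matching `Ref(int32)` registration (and with `lock_held=true` but no lock actually held).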
