I'm trying to modify the simple "Adding a New Op" example so that, instead of creating a new Tensor as its return value, it mutates the input Tensor and returns that. I know this is possible because the scatter Op does the same thing, but looking at the scatter Op source code, I cannot figure out exactly what to do given my lack of C++ experience.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    Tensor input_tensor = context->mutable_input(0, true);
    auto input = input_tensor.flat<int32>();

    // We always return the input ref.
    context->forward_ref_input_to_ref_output(0, 0);

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      input(i) = 0;
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
If I compile the above code and run a simple Python script to test it, I get the following error:
Python(14820,0x700003a7e000) malloc: *** error for object 0x7fd5c45a5a88: pointer being freed was not allocated
*** set a breakpoint in malloc_error_break to debug
What do I need to change in my code to make this work?
I think you need to change how you grab the input and how you set the output. According to your REGISTER_OP, to_zero is not a reference input, so
context->mutable_input(0, true)
should be
context->input(0)
Likewise, the call that sets the output should be changed to
context->set_output(0, context->input(0))
I think it will work once the output is set this way.
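To make the goal concrete without any TensorFlow machinery, here is a minimal plain C++ sketch of the behaviour the question is after: zero everything but the first element directly in the caller's buffer and hand back that same buffer. ZeroOutInPlace is a hypothetical helper name used only for illustration; it is not part of TensorFlow and says nothing about how the kernel itself must be written.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not a TensorFlow API): zeroes every element except
// the first, writing directly into the caller's buffer.
std::vector<std::int32_t>& ZeroOutInPlace(std::vector<std::int32_t>& values) {
  for (std::size_t i = 1; i < values.size(); ++i) {
    values[i] = 0;
  }
  // Return the very same object that was passed in; no copy is made.
  return values;
}
```

The function takes and returns a non-const reference, which is what lets the caller observe the change in the original buffer. That is the same distinction the answer draws between a reference input and an ordinary one.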