LibTorch C++: Converting Tensor back to Image, Result is 3x3 Grid?

I am working on code that takes a matrix object (assume it behaves like cv::Mat), converts it to a tensor, runs a forward pass through my model, and converts the result back to my matrix object. One issue persists: the resulting matrix shows the result tiled in a 3x3 grid. I tested my conversion code (Matrix to Tensor and back) by passing an image through both functions without the model, and the resulting image is correct. This leads me to believe the problem lies in how the forward pass creates the output tensor. How should I handle the output tensor to fix this?

Code for context

Matrix to Tensor:

int numel = rows * cols * depth;
assert(numel > 0);

tensor_image = torch::zeros({ rows, cols, depth }, torch::kFloat);

std::memcpy(tensor_image.data_ptr<float>(), Image.GetConstDataPtr(), sizeof(float) * numel);

tensor_image = tensor_image.permute({ 2, 0, 1 }).unsqueeze(0);

Forward Pass:

device = torch::kCUDA;

// Move the model to the GPU
module.to(device);

// Move the input to the GPU
tensor_image = tensor_image.to(device);

// Run the model; forward() returns a new tensor, so no pre-allocated output is needed
// (declaring `output` twice in the same scope would not compile)
at::Tensor output = module.forward({ tensor_image }).toTensor();

// Bring the result back to the CPU for the conversion to a matrix
output = output.to(torch::kCPU);

Tensor to Matrix:

int numel = height * width * depth;
assert(numel > 0);

Image.Resize(height, width, depth);

at::Tensor tensor_image_cpy = tensor_image.squeeze(0).permute({ 1, 2, 0 });

std::memcpy(Image.GetDataPtr(), tensor_image_cpy.data_ptr<float>(), sizeof(float) * numel);

I have found the solution to my problem. Citing yugo's answer to this question:

Convert pytorch tensor to opencv mat and vice versa in C++

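You need to reshape the data after the permute and before the memcpy. permute() only changes the tensor's strides, so data_ptr<float>() still points at memory in the model's channel-first order. Reshaping the permuted tensor to one dimension yields data laid out in height x width x depth order:

tensor_image_cpy = tensor_image_cpy.reshape({ width * height * depth });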
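
To illustrate what goes wrong with a raw memcpy, here is a minimal sketch in plain C++ with no LibTorch involved. chw_to_hwc is a hypothetical helper, not a LibTorch function. It assumes the model output is stored channel-first (depth x height x width) and the matrix expects channel-last (height x width x depth), and it performs by hand the element reordering that a copy in permuted order has to do.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of LibTorch): copy a channel-first buffer
// (depth x height x width) into a channel-last buffer (height x width x depth).
// This is the reordering that a raw memcpy from a permuted view skips.
std::vector<float> chw_to_hwc(const std::vector<float>& chw,
                              std::size_t depth,
                              std::size_t height,
                              std::size_t width) {
    assert(chw.size() == depth * height * width);
    std::vector<float> hwc(chw.size());
    for (std::size_t c = 0; c < depth; ++c) {
        for (std::size_t h = 0; h < height; ++h) {
            for (std::size_t w = 0; w < width; ++w) {
                // Source offset in channel-first order.
                const std::size_t src = (c * height + h) * width + w;
                // Destination offset in channel-last order.
                const std::size_t dst = (h * width + w) * depth + c;
                hwc[dst] = chw[src];
            }
        }
    }
    return hwc;
}
```

For a 2x2 image with 3 channels stored channel-first as {0,1,2,3, 10,11,12,13, 20,21,22,23}, the helper returns {0,10,20, 1,11,21, 2,12,22, 3,13,23}. A raw memcpy would leave the channel-first order in place, so the matrix would read values from a single channel as the three channels of neighbouring pixels. In the round-trip test without the model, the storage was presumably still in the original height x width x depth order, which would explain why the same memcpy happened to work there.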

Hope this helps others who run into this issue.
