
How do I get the image back from a Tensor in libtorch?

I am trying to get the image back from a tensor I created earlier and visualize it. However, the resulting image seems to be distorted garbage. This is how I convert a CV_8UC3 cv::Mat into the corresponding at::Tensor:

at::Tensor tensor_image = torch::from_blob(img.data, { img.rows, img.cols, 3 }, at::kByte);
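torch::from_blob does not copy the pixels; it wraps img.data as-is. The shape { img.rows, img.cols, 3 } with at::kByte therefore describes the data correctly only because a continuous CV_8UC3 Mat stores one byte per channel, row by row, with the three channels interleaved (HWC order). It also means sizes()[0] is the number of rows and sizes()[1] the number of columns. The following plain C++ sketch (no OpenCV or libtorch; hwc_offset and make_test_image are hypothetical helpers) shows that offset arithmetic, assuming the Mat is continuous, i.e. not a ROI with padded rows.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Byte offset of channel `ch` of pixel (row, col) in a continuous,
// interleaved 8-bit image. This is the layout of a continuous CV_8UC3
// cv::Mat and of a contiguous kByte tensor of shape {rows, cols, 3}.
std::size_t hwc_offset(std::size_t row, std::size_t col, std::size_t ch,
                       std::size_t cols, std::size_t channels) {
    return (row * cols + col) * channels + ch;
}

// Hypothetical helper: builds a rows x cols x 3 buffer in which every byte
// encodes its own position, so a wrong indexing scheme is easy to spot.
std::vector<std::uint8_t> make_test_image(std::size_t rows, std::size_t cols) {
    std::vector<std::uint8_t> buf(rows * cols * 3);
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < cols; ++c)
            for (std::size_t ch = 0; ch < 3; ++ch)
                buf[hwc_offset(r, c, ch, cols, 3)] =
                    static_cast<std::uint8_t>(r * 100 + c * 10 + ch);
    return buf;
}
```

For a 2 x 4 image, pixel (1, 2) starts at byte (1 * 4 + 2) * 3 = 18. If img.isContinuous() is false, the rows are padded and a plain {rows, cols, 3} view no longer matches the memory.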

and this is how I convert it back to the image:

auto ToCvImage(at::Tensor tensor)
{
    int width = tensor.sizes()[0];
    int height = tensor.sizes()[1];
    try
    {
        cv::Mat output_mat(cv::Size{ height, width }, CV_8UC3, tensor.data_ptr<int>());
        return output_mat.clone();
    }
    catch (const c10::Error& e)
    {
        std::cout << "an error has occurred : " << e.msg() << std::endl;
    }
    return cv::Mat(height, width, CV_8UC3);
}

This is what the original image looks like:

[image: the original picture]

and this is what I get after conversion:

[image: the distorted result after conversion]

Now, if I use at::kInt instead of at::kByte when creating the tensor:

at::Tensor tensor_image = torch::from_blob(img.data, { img.rows, img.cols, 3 }, at::kInt);

I no longer get the distorted image; however, the network output is off, which means something has gone wrong with the input!

What's the issue here and how should I be going about this?
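As for the at::kInt variant above, from_blob does not convert anything: it reinterprets every 4 consecutive pixel bytes as one 32-bit int, and a {rows, cols, 3} kInt tensor claims four times as many bytes as img.data actually holds. The displayed image can still look right (the CV_8UC3 Mat only reads the first rows * cols * 3 bytes, which are the original pixels), while the values the network sees are wrong. A minimal plain C++ sketch of that reinterpretation, without libtorch (tensor_bytes and view_as_int32 are hypothetical helpers):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Number of bytes a {rows, cols, 3} tensor spans when each element is
// `elem_size` bytes wide (1 for kByte, 4 for kInt).
std::size_t tensor_bytes(std::size_t rows, std::size_t cols, std::size_t elem_size) {
    return rows * cols * 3 * elem_size;
}

// Views the first `count` elements of an 8-bit pixel buffer as 32-bit ints,
// the way a tensor created with from_blob(img.data, ..., at::kInt) would.
// memcpy avoids the alignment and aliasing problems of a raw pointer cast.
std::vector<std::int32_t> view_as_int32(const std::vector<std::uint8_t>& bytes,
                                        std::size_t count) {
    assert(bytes.size() >= count * sizeof(std::int32_t));  // stay in bounds
    std::vector<std::int32_t> out(count);
    std::memcpy(out.data(), bytes.data(), count * sizeof(std::int32_t));
    return out;
}
```

Four pixel bytes {1, 1, 1, 1} come out as the single int 16843009 (0x01010101) instead of four values of 1, independent of byte order.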

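When the tensor is created with c10::kByte, the template argument of data_ptr<T>() must be uchar (an 8-bit unsigned type), not char, int, uint, etc. data_ptr<T>() checks T against the tensor's dtype and throws a c10::Error on a mismatch, so with data_ptr<int>() the original code ends up in the catch block and returns an uninitialized cv::Mat. To fix it, I only had to use uchar instead of int: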

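#include <iostream>
#include <opencv2/core.hpp>
#include <torch/torch.h>

cv::Mat ToCvImage(at::Tensor tensor)
{
    // The tensor is laid out as {rows, cols, 3} (HWC), like the CV_8UC3 Mat it came from.
    int rows = tensor.sizes()[0];
    int cols = tensor.sizes()[1];
    try
    {
        // contiguous() makes sure the bytes are densely packed in HWC order
        // (it returns the tensor itself if it is already contiguous).
        at::Tensor packed = tensor.contiguous();
        // The template argument of data_ptr must match the dtype: kByte <-> uchar.
        cv::Mat output_mat(cv::Size{ cols, rows }, CV_8UC3, packed.data_ptr<uchar>());
        // clone() copies the pixels, so the result does not alias the tensor's memory.
        return output_mat.clone();
    }
    catch (const c10::Error& e)
    {
        std::cout << "an error has occurred : " << e.msg() << std::endl;
    }
    return cv::Mat(rows, cols, CV_8UC3);
}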
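The dtype check that libtorch's data_ptr<T>() performs can be modeled in a few lines of plain C++. This is a toy sketch, not the libtorch API: DType, dtype_of and MockTensor are hypothetical names. It only shows why asking a byte tensor for int* fails while asking for an 8-bit unsigned pointer succeeds.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical dtype tag for this toy model; it is NOT the libtorch API.
enum class DType { Byte, Int };

// Maps a C++ element type to the dtype tag it corresponds to.
template <typename T> constexpr DType dtype_of();
template <> constexpr DType dtype_of<std::uint8_t>() { return DType::Byte; }
template <> constexpr DType dtype_of<std::int32_t>() { return DType::Int; }

// Toy tensor: raw bytes plus a dtype tag. data_ptr<T>() refuses a T that
// does not match the tag, mimicking the check libtorch performs.
struct MockTensor {
    DType dtype;
    std::vector<std::uint8_t> bytes;

    template <typename T>
    T* data_ptr() {
        if (dtype_of<T>() != dtype) {
            throw std::runtime_error("requested element type does not match dtype");
        }
        return reinterpret_cast<T*>(bytes.data());
    }
};
```

With MockTensor t{DType::Byte, {10, 20, 30}}, t.data_ptr<std::uint8_t>()[1] yields 20, whereas t.data_ptr<std::int32_t>() throws, just as data_ptr<int>() on a kByte tensor sends the question's code into its catch block.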

Side note: if you created your tensor with any other type, make sure to convert it to the proper type beforehand with Tensor::to() (or the older Tensor::toType()). In my case, before feeding this tensor to the network, I convert it to at::kFloat and then carry on. It's an obvious point that can easily be neglected and cost you hours of debugging!
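The important difference is that a dtype conversion such as tensor.to(at::kFloat) converts each value (200 becomes 200.0f), whereas creating the tensor with a wider dtype in from_blob would only reinterpret the bytes. A plain C++ sketch of element-wise conversion (to_float is a hypothetical helper; the optional scaling to [0, 1] is an assumption about what a given network expects, not something the answer states):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Converts 8-bit pixel values to float element by element, keeping each
// value (200 -> 200.0f). Optionally scales to [0, 1]; whether a network
// expects that scaling is an assumption that depends on how it was trained.
std::vector<float> to_float(const std::vector<std::uint8_t>& px, bool scale_to_unit) {
    std::vector<float> out(px.size());
    const float factor = scale_to_unit ? 1.0f / 255.0f : 1.0f;
    std::transform(px.begin(), px.end(), out.begin(),
                   [factor](std::uint8_t v) { return static_cast<float>(v) * factor; });
    return out;
}
```

Keeping the pixel buffer as bytes for display and converting a separate copy to float for the network avoids mixing the two representations.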
