![](/img/trans.png)
[英]How to convert PyTorch tensor to C++ torch::Tensor vice versa?
[英]How to use ITK to convert PNG into tensor for PyTorch
我正在尝试运行 C++ PyTorch 框架并遇到以下问题。
我成功编写了 model 脚本,现在可以运行了。 现在我必须将png
图像输入 model。
我在互联网上发现了一个有类似问题的人,他的想法是使用ITK
模块读取 PNG 文件并将其转换为数组,然后将其转换为Tensor
。
PNG -> RGBPixel[] -> tensor
所以以下是我现在正在尝试的。
using PixelTyupe = itk::RGBPixel<unsinged char>;
const unsigned int Dimension = 3;
typedef itk::Image<PixelType, Dimension> ImageType;
typedef itk::ImageFileReader<ImageType> ReaderType;
typedef itk::ImageRegionIterator<ImageType> IteratorType;
typename ImageType::RegionType region = itk_img->GetLargestPossibleRegion();
const typename ImageType::SizeType size = region.GetSize();
int len = size[0] * size[1] * size[2]; // This ends up 1920 * 1080 * 1
PixelType rowdata[len];
int count = 0;
IteratorType iter(itk_img, itk_img->GetRequestedRegion());
// convert itk to array
for (iter.GoToBegin(); !iter.IsAtEnd(); ++iter) {
rowdata[count] = iter.Get();
count++;
} // count = 1920 * 1080
// convert array to tensor
tensor_img = torch::from_blob(rowdata, {3, (int)size[0], (int)size[1]}, torch::kShort). clone(); // Segmenation Fault
当我尝试打印日志数据时,它包含三个数字,例如84 85 83
,所以我认为 PNG 文件已成功读取。
但是,我无法让代码的最后一部分工作。 我需要的是3:1920:1080
张量,但我认为 function 不能正确理解这三个 RGBPixel 值。
除此之外,我不明白为什么将维度设置为 3。
我将不胜感激任何帮助。
您不需要维度 3, Dimension = 2
就足够了。 如果你需要的布局是RGBx1920x1080,那么PixelType* rowdata = itk_img->GetBufferPointer();
无需进一步处理即可获得该布局。 由于torch::from_blob
不拥有缓冲区的所有权,因此其他人试图使用.clone()
。 您也不必这样做,假设您将itk_img
保留在 scope 中,或者使用它的引用计数和deleter
乱七八糟。
崩溃可能来自说缓冲区有短像素( torch::kShort
),当它有uchar
时。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.