
In a PyTorch C++ extension, how can I access a single element of a tensor and convert it to a standard C++ data type?

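I'm writing a C++ extension for PyTorch in which I need to access tensor elements by index, and I also need to convert those elements to standard C++ types. Here is a short example. Suppose I have a 2-D tensor a and I need to access a[i][j] and convert it to a float.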

#include <torch/extension.h>

float get(torch::Tensor a, int i, int j) {
    return a[i][j];
}

The code above is saved in a file named tensortest.cpp. In a separate file, setup.py, I wrote:

from setuptools import setup, Extension
from torch.utils import cpp_extension

setup(name='tensortest',
      ext_modules=[cpp_extension.CppExtension('tensortest_cpp', ['tensortest.cpp'])],
      cmdclass={'build_ext': cpp_extension.BuildExtension})

When I run python setup.py install, the compiler reports the following error:

running install
running bdist_egg
running egg_info
creating tensortest.egg-info
writing tensortest.egg-info/PKG-INFO
writing dependency_links to tensortest.egg-info/dependency_links.txt
writing top-level names to tensortest.egg-info/top_level.txt
writing manifest file 'tensortest.egg-info/SOURCES.txt'
/home/trisst/.local/lib/python3.8/site-packages/torch/utils/cpp_extension.py:335: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
  warnings.warn(msg.format('we could not find ninja.'))
reading manifest file 'tensortest.egg-info/SOURCES.txt'
writing manifest file 'tensortest.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'tensortest_cpp' extension
creating build
creating build/temp.linux-x86_64-3.8
x86_64-linux-gnu-gcc -pthread -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/home/user/.local/lib/python3.8/site-packages/torch/include -I/home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/user/.local/lib/python3.8/site-packages/torch/include/TH -I/home/user/.local/lib/python3.8/site-packages/torch/include/THC -I/usr/include/python3.8 -c tensortest.cpp -o build/temp.linux-x86_64-3.8/tensortest.o -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=tensortest_cpp -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14
In file included from /home/user/.local/lib/python3.8/site-packages/torch/include/ATen/Parallel.h:149,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/utils.h:3,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/nn/cloneable.h:5,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/nn.h:3,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:7,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/extension.h:4,
                 from tensortest.cpp:1:
/home/user/.local/lib/python3.8/site-packages/torch/include/ATen/ParallelOpenMP.h:84: warning: ignoring #pragma omp parallel [-Wunknown-pragmas]
   84 | #pragma omp parallel for if ((end - begin) >= grain_size)
      | 
tensortest.cpp: In function ‘float get(at::Tensor, int, int)’:
tensortest.cpp:4:15: error: cannot convert ‘at::Tensor’ to ‘float’ in return
    4 |  return a[i][j];
      |               ^
error: command 'x86_64-linux-gnu-gcc' failed with exit status 1

What can I do?

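Edit: indexing a tensor with a[i][j] returns another at::Tensor (a 0-dimensional one), not a float, which is exactly what the compiler error says. Calling item<float>() on it extracts the single value as a C++ float: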

#include <torch/extension.h>

float get(torch::Tensor a, int i, int j) 
{
    return a[i][j].item<float>(); 
}


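Notice: technical posts on this site are licensed under CC BY-SA 4.0. If you repost this content, please cite this site's URL or the original address. For any questions, contact: yoyou2525@163.com.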

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM