繁体   English   中英

PyTorch C++ 扩展:如何索引张量并更新它?

[英]PyTorch C++ extension: How to index tensor and update it?

我正在创建一个 PyTorch C++ 扩展,经过大量研究后,我无法弄清楚如何索引张量并更新其值。 我发现了如何使用data_ptr()方法迭代张量的条目,但这不适用于我的用例。

给定的是一个矩阵M ,一个索引对P的列表(块)列表和一个函数f: dtype(M)^2 -> dtype(M)^2它接受两个值并吐出两个新值。

我正在尝试实现以下伪代码:

for each block B in P:
    for each row R in M:
        for each index-pair (i,j) in B:
            M[R,i], M[R,j] = f(M[R,i], M[R,j])

毕竟,这段代码将使用 CUDA 在 GPU 上运行,但由于我对此没有任何经验,所以我想先编写一个纯 C++ 程序,然后对其进行转换。

任何人都可以建议如何执行此操作或如何将算法转换为执行等效操作?

我想要做的可以使用tensor.accessor<scalar_dtype, num_dimensions>()方法来完成。 如果在 GPU 上执行,请改用scalars.packed_accessor64<scalar_dtype, num_dimensions, torch::RestrictPtrTraits>()scalars.packed_accessor32<scalar_dtype, num_dimensions, torch::RestrictPtrTraits>() (取决于您的大小)。

auto num_rows = scalars.size(0);
matrix = torch::rand({10, 8});
auto a = matrix.accessor<float, 2>();
for (auto i = 0; i < num_rows; ++i) {
    auto x = a[i][some_index];
    auto new_x = some_function(x);
    a[i][some_index] = new_x;
}

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM