
Select specific rows of 2D PyTorch tensor

Suppose I have a 2D tensor looking something like this:

[[44, 50, 1, 32],
 ...
 [7, 13, 90, 83]]

and a list of row indices that I want to select, something like [0, 34, 100, ..., 745]. How can I create a new tensor that contains only the rows whose indices appear in that list?

You can index the tensor with the list directly, just as you would with a NumPy array:

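import torch

# 4x4 float32 tensor
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 8, 7, 6],
                  [5, 4, 2, 1]], dtype=torch.float32)

# Plain Python list of the row indices to keep
indices = [0, 3]

# Advanced indexing on the first dimension selects those rows
print(x[indices])
# tensor([[1., 2., 3., 4.],
#         [5., 4., 2., 1.]])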
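If the indices are already held in a tensor, torch.index_select(input, dim, index) is an equivalent option. The sketch below reuses the same toy 4x4 tensor; the index values [3, 0, 3] are illustrative only, chosen to show that the requested order and any repeated indices are preserved, and that the result is a new tensor rather than a view of x.

```python
import torch

# Same toy 4x4 float32 tensor as above
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 8, 7, 6],
                  [5, 4, 2, 1]], dtype=torch.float32)

# Illustrative row indices as an int64 tensor; order and repeats are kept
idx = torch.tensor([3, 0, 3], dtype=torch.long)

# index_select gathers the given rows along dim 0
rows = torch.index_select(x, 0, idx)

# Advanced indexing with the index tensor gives the same values
same = x[idx]

print(rows)
print(torch.equal(rows, same))  # True

# Both results are copies, not views: editing one leaves x unchanged
rows[0, 0] = 100.0
print(x[3, 0].item())  # 5.0
```

Indexing with a list is usually the most readable choice; index_select is convenient when the indices come from another tensor operation (for example, on the same device as x).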


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM