簡體   English   中英

PyTorch:索引具有2D張量行索引的2D張量

[英]PyTorch: index 2D tensor with 2D tensor of row indices

我有一個形狀(x, n)的火炬張量a和另一個形狀(y, n)張量b ,其中y <= x 的每一列b包含行索引的序列對a和我想能夠做的是某種方式索引ab ,使得我獲得形狀的張量(y, n)其中第i列包含a[:, i][b[:, i]] (不太確定這是否是表達它的正確方法)。

這是一個例子(其中x = 5, y = 3和n = 4):

import torch

a = torch.Tensor(
    [[0.1, 0.2, 0.3, 0.4],
     [0.6, 0.7, 0.8, 0.9],
     [1.1, 1.2, 1.3, 1.4],
     [1.6, 1.7, 1.8, 1.9],
     [2.1, 2.2, 2.3, 2.4]]
)

b = torch.LongTensor(
    [[0, 3, 1, 2],
     [2, 2, 2, 0],
     [1, 1, 0, 4]]
)

# How do I get from a and b to c
# (so that I can also assign to those elements in a)?

c = torch.Tensor(
    [[0.1, 1.7, 0.8, 1.4],
     [1.1, 1.2, 1.3, 0.4],
     [0.6, 0.7, 0.3, 2.4]]
)

我無法理解這一點。 我正在尋找的是不會產生張量的方法c ,但也讓我分配相同形狀的張量c到的元素ac是由。

我嘗試使用index_select但它只支持索引的1-dim數組。

bt = b.transpose(0, 1)
at = a.transpose(0, 1)
ct = [torch.index_select(at[i], dim=0, index=bt[i]) for i in range(len(at))]
c  = torch.stack(ct).transpose(0, 1)
print(c)
"""
tensor([[0.1000, 1.7000, 0.8000, 1.4000],
        [1.1000, 1.2000, 1.3000, 0.4000],
        [0.6000, 0.7000, 0.3000, 2.4000]])
"""

它可能不是最好的解決方案,但希望這至少可以幫到你。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM