简体   繁体   English

如何实现PyTorch中的多维花式索引?

[英]How to accomplish the multi-dimensional fancy indexing in PyTorch?

I have a 3D tensor ( [batch_size, seq_length, hidden_dim] ) and a 2D list ( [batch_size, seq_length] ).我有一个3D 张量[batch_size, seq_length, hidden_dim] )和一个二维列表[batch_size, seq_length] )。

I want to using list to accomplish the selection of this tensor.我想用 list 来完成这个张量的选择。

For example: the shape of 3D tensor t [2, 5, 3] and the shape of 2D list l [2, 5] .例如:3D 张量t [2, 5, 3]的形状和二维列表l [2, 5]的形状。

Let t0 = t[0, :, :] and l0 = l[0] .t0 = t[0, :, :]l0 = l[0] I would like to select " t0[l0] ".我想 select“ t0[l0] ”。

Same for t1[l1] , and so on. t1[l1]相同,依此类推。

I can only think of writng a for loop for achieve this:我只能想到编写一个for循环来实现这一点:

new_ts = [] 
for i in range(t.shape[0]):
    new_t = t[i][l[i]]
    new_ts.append(new_t) 
new_t = torch.cat(new_ts, dim=2)

There must be a more simple way to accomplish this.必须有一种更简单的方法来实现这一点。 I have also tried multi-dimensional fancy indexing t[l] , but the syntax is not valid and it doesn't work.我也尝试过多维花式索引t[l] ,但语法无效并且不起作用。

Looking forward to your suggestions.期待您的建议。

The code provided doesn't work, it seems you are looking to concatenate the tensors in the list on an axis that doesn't yet exit.提供的代码不起作用,您似乎希望将列表中的张量连接到尚未退出的轴上。 I suggest you use stack instead.我建议你改用stack

I don't think you can achieve this without a single loop.我认为没有一个循环就无法实现这一目标。 In terms of syntax and readability you can do a little better, in my opinion...在我看来,就语法和可读性而言,你可以做得更好......

Having both starting tensors t and l :同时具有起始张量tl

t = torch.rand(2,5,3)
l = torch.randint(0,5, size=(2,5))

Using a list comprehension:使用列表推导:

torch.stack([t[i][jj] for i, jj in enumerate(l)])

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM