
PyTorch how to do gathers over multiple dimensions

I'm trying to find a way to do this without for loops.

Say I have a multi-dimensional tensor t0:

import torch

bs = 4
seq = 10
sample = 5  # number of indices drawn per batch element
v = 16
t0 = torch.rand((bs, seq, v))

This has shape: torch.Size([4, 10, 16])

I have another tensor labels that is a batch of 5 random indices in the seq dimension:

labels = torch.randint(0, seq, size=[bs, sample])

So this has shape torch.Size([4, 5]). It is used to index the seq dimension of t0.

What I want to do is loop over the batch dimension, doing gathers with the labels tensor. My brute force solution is this:

t1 = torch.empty((bs, sample, v))
for b in range(bs):
    for idx0, idx1 in enumerate(labels[b]):
        t1[b, idx0, :] = t0[b, idx1, :]

The result is a tensor t1 of shape torch.Size([4, 5, 16]).

Is there a more idiomatic way of doing this in PyTorch?

You can use fancy indexing here to select the desired portion of the tensor.

Essentially, if you generate the index arrays conveying your access pattern beforehand, you can use them directly to extract a slice of the tensor. The shape of the index array for each dimension should match that of the output tensor or slice you want to extract.

i = torch.arange(bs).reshape(bs, 1, 1)  # shape = [bs, 1,      1]
j = labels.reshape(bs, sample, 1)       # shape = [bs, sample, 1]
k = torch.arange(v)                     # shape = [v]

# Get the result as
t1 = t0[i, j, k]

Note the shapes of the above 3 tensors. Broadcasting prepends extra dimensions to the front of a tensor, essentially reshaping k to shape [1, 1, v], which makes all three of them compatible for elementwise operations.

After broadcasting, (i, j, k) together produce three [bs, sample, v]-shaped index arrays, and those (elementwise) index your original tensor to produce the output tensor t1 of shape [bs, sample, v].
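
To see this concretely, you can ask PyTorch to compute the broadcast of the three index shapes (a quick sanity check; torch.broadcast_shapes is available in recent PyTorch releases):

print(torch.broadcast_shapes(i.shape, j.shape, k.shape))
# torch.Size([4, 5, 16]), i.e. [bs, sample, v]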

You could do it like this:

t1 = t0[[[b] for b in range(bs)], labels]

or

t1 = torch.stack([t0[b, labels[b]] for b in range(bs)])
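
For completeness, the same result can also be obtained with torch.gather by expanding labels along the feature dimension so the index tensor matches the rank of t0 (a sketch reusing the names defined above):

idx = labels.unsqueeze(-1).expand(bs, sample, v)  # shape = [bs, sample, v]
t1 = torch.gather(t0, 1, idx)  # gather along the seq dimension (dim=1)

All of these approaches produce the same [bs, sample, v] tensor as the explicit loop.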
