繁体   English   中英

PyTorch 如何在多个维度上进行聚集

[英]PyTorch how to do gathers over multiple dimensions

我正在尝试找到一种无需 for 循环的方法。

假设我有一个多维张量t0

bs = 4
seq = 10
v = 16
t0 = torch.rand((bs, seq, v))

这有形状: torch.Size([4, 10, 16])

我有另一个张量labels ,它是seq维度中的一批 5 个随机索引:

labels = torch.randint(0, seq, size=[bs, sample])

所以这有形状torch.Size([4, 5]) 这用于索引t0seq维度。

我想要做的是循环使用labels张量在批处理维度上进行收集。 我的蛮力解决方案是这样的:

t1 = torch.empty((bs, sample, v))
for b in range(bs):
    for idx0, idx1 in enumerate(labels[b]):
        t1[b, idx0, :] = t0[b, idx1, :]

导致张量t1的形状为: torch.Size([4, 5, 16])

在 pytorch 中有更惯用的方法吗?

您可以在此处使用花哨的索引来 select 张量的所需部分。

本质上,如果您事先生成索引 arrays 来传达您的访问模式,您可以直接使用它们来提取张量的一些切片。 每个维度的索引 arrays 的形状应与您要提取的 output 张量或切片的形状相同。

i = torch.arange(bs).reshape(bs, 1, 1) # shape = [bs, 1,      1]
j = labels.reshape(bs, sample, 1)      # shape = [bs, sample, 1]
k = torch.arange(v)                    # shape = [v, ]

# Get result as
t1 = t0[i, j, k]

请注意上述 3 个张量的形状。 广播在张量的前面附加了额外的维度,因此基本上将k重塑为[1, 1, v]形状,这使得它们中的所有 3 个都与元素操作兼容。

广播后(i, j, k)一起将产生 3 个[bs, sample, v]形状的 arrays 并且这些将(按元素)索引您的原始张量以产生形状为[bs, sample, v]的 output 张量t1

你可以这样做:

t1 = t0[[[b] for b in range(bs)], labels]

或者

t1 = torch.stack([t0[b, labels[b]] for b in range(bs)])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM