[英]PyTorch how to do gathers over multiple dimensions
我正在尝试找到一种无需 for 循环的方法。
假设我有一个多维张量t0
:
bs = 4
seq = 10
v = 16
t0 = torch.rand((bs, seq, v))
这有形状: torch.Size([4, 10, 16])
我有另一个张量labels
,它是seq
维度中的一批 5 个随机索引:
labels = torch.randint(0, seq, size=[bs, sample])
所以这有形状torch.Size([4, 5])
。 这用于索引t0
的seq
维度。
我想要做的是循环使用labels
张量在批处理维度上进行收集。 我的蛮力解决方案是这样的:
t1 = torch.empty((bs, sample, v))
for b in range(bs):
for idx0, idx1 in enumerate(labels[b]):
t1[b, idx0, :] = t0[b, idx1, :]
导致张量t1
的形状为: torch.Size([4, 5, 16])
在 pytorch 中有更惯用的方法吗?
您可以在此处使用花哨的索引来 select 张量的所需部分。
本质上,如果您事先生成索引 arrays 来传达您的访问模式,您可以直接使用它们来提取张量的一些切片。 每个维度的索引 arrays 的形状应与您要提取的 output 张量或切片的形状相同。
i = torch.arange(bs).reshape(bs, 1, 1) # shape = [bs, 1, 1]
j = labels.reshape(bs, sample, 1) # shape = [bs, sample, 1]
k = torch.arange(v) # shape = [v, ]
# Get result as
t1 = t0[i, j, k]
请注意上述 3 个张量的形状。 广播在张量的前面附加了额外的维度,因此基本上将k
重塑为[1, 1, v]
形状,这使得它们中的所有 3 个都与元素操作兼容。
广播后(i, j, k)
一起将产生 3 个[bs, sample, v]
形状的 arrays 并且这些将(按元素)索引您的原始张量以产生形状为[bs, sample, v]
的 output 张量t1
。
你可以这样做:
t1 = t0[[[b] for b in range(bs)], labels]
或者
t1 = torch.stack([t0[b, labels[b]] for b in range(bs)])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.