[英]PyTorch how to do gathers over multiple dimensions
我正在嘗試找到一種無需 for 循環的方法。
假設我有一個多維張量t0
:
bs = 4
seq = 10
v = 16
t0 = torch.rand((bs, seq, v))
這有形狀: torch.Size([4, 10, 16])
我有另一個張量labels
,它是seq
維度中的一批 5 個隨機索引:
labels = torch.randint(0, seq, size=[bs, sample])
所以這有形狀torch.Size([4, 5])
。 這用於索引t0
的seq
維度。
我想要做的是循環使用labels
張量在批處理維度上進行收集。 我的蠻力解決方案是這樣的:
t1 = torch.empty((bs, sample, v))
for b in range(bs):
for idx0, idx1 in enumerate(labels[b]):
t1[b, idx0, :] = t0[b, idx1, :]
導致張量t1
的形狀為: torch.Size([4, 5, 16])
在 pytorch 中有更慣用的方法嗎?
您可以在此處使用花哨的索引來 select 張量的所需部分。
本質上,如果您事先生成索引 arrays 來傳達您的訪問模式,您可以直接使用它們來提取張量的一些切片。 每個維度的索引 arrays 的形狀應與您要提取的 output 張量或切片的形狀相同。
i = torch.arange(bs).reshape(bs, 1, 1) # shape = [bs, 1, 1]
j = labels.reshape(bs, sample, 1) # shape = [bs, sample, 1]
k = torch.arange(v) # shape = [v, ]
# Get result as
t1 = t0[i, j, k]
請注意上述 3 個張量的形狀。 廣播在張量的前面附加了額外的維度,因此基本上將k
重塑為[1, 1, v]
形狀,這使得它們中的所有 3 個都與元素操作兼容。
廣播后(i, j, k)
一起將產生 3 個[bs, sample, v]
形狀的 arrays 並且這些將(按元素)索引您的原始張量以產生形狀為[bs, sample, v]
的 output 張量t1
。
你可以這樣做:
t1 = t0[[[b] for b in range(bs)], labels]
或者
t1 = torch.stack([t0[b, labels[b]] for b in range(bs)])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.