簡體   English   中英

PyTorch 如何在多個維度上進行聚集

[英]PyTorch how to do gathers over multiple dimensions

我正在嘗試找到一種無需 for 循環的方法。

假設我有一個多維張量t0

bs = 4
seq = 10
v = 16
t0 = torch.rand((bs, seq, v))

這有形狀: torch.Size([4, 10, 16])

我有另一個張量labels ,它是seq維度中的一批 5 個隨機索引:

labels = torch.randint(0, seq, size=[bs, sample])

所以這有形狀torch.Size([4, 5]) 這用於索引t0seq維度。

我想要做的是循環使用labels張量在批處理維度上進行收集。 我的蠻力解決方案是這樣的:

t1 = torch.empty((bs, sample, v))
for b in range(bs):
    for idx0, idx1 in enumerate(labels[b]):
        t1[b, idx0, :] = t0[b, idx1, :]

導致張量t1的形狀為: torch.Size([4, 5, 16])

在 pytorch 中有更慣用的方法嗎?

您可以在此處使用花哨的索引來 select 張量的所需部分。

本質上,如果您事先生成索引 arrays 來傳達您的訪問模式,您可以直接使用它們來提取張量的一些切片。 每個維度的索引 arrays 的形狀應與您要提取的 output 張量或切片的形狀相同。

i = torch.arange(bs).reshape(bs, 1, 1) # shape = [bs, 1,      1]
j = labels.reshape(bs, sample, 1)      # shape = [bs, sample, 1]
k = torch.arange(v)                    # shape = [v, ]

# Get result as
t1 = t0[i, j, k]

請注意上述 3 個張量的形狀。 廣播在張量的前面附加了額外的維度,因此基本上將k重塑為[1, 1, v]形狀,這使得它們中的所有 3 個都與元素操作兼容。

廣播后(i, j, k)一起將產生 3 個[bs, sample, v]形狀的 arrays 並且這些將(按元素)索引您的原始張量以產生形狀為[bs, sample, v]的 output 張量t1

你可以這樣做:

t1 = t0[[[b] for b in range(bs)], labels]

或者

t1 = torch.stack([t0[b, labels[b]] for b in range(bs)])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM