簡體   English   中英

在 Pytorch 中重復張量的特定列

[英]Repeat specific columns of a tensor in Pytorch

我有一個大小為mxn的 pytorch 張量X和一個長度為n的非負整數num_repeats列表(假設 sum(num_repeat)>0)。 在 forward() 方法中,我想創建一個大小為mx sum(num_repeats)的張量X_dup ,其中X i列重復num_repeats[i]次。 張量X_dup將在 forward() 方法的下游使用,因此梯度需要正確反向傳播。 我能想出的所有解決方案都需要就地操作(創建一個新的張量並通過迭代num_repeats填充它),但如果我理解正確,這將不會保留梯度(如果我錯了,請糾正我,我是新手到整個 Pytorch 的事情)。

如果您使用的是 PyTorch >= 1.1.0,您可以使用torch.repeat_interleave

repeat_tensor = torch.tensor(num_repeats).to(X.device, torch.int64)
X_dup = torch.repeat_interleave(X, repeat_tensor, dim=1)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM