[英]What is an efficient way to replace specific columns with a subset in pytorch tensor
[英]Repeat specific columns of a tensor in Pytorch
我有一個大小為mxn
的 pytorch 張量X
和一個長度為n
的非負整數num_repeats
列表(假設 sum(num_repeat)>0)。 在 forward() 方法中,我想創建一個大小為mx sum(num_repeats)
的張量X_dup
,其中X
i
列重復num_repeats[i]
次。 張量X_dup
將在 forward() 方法的下游使用,因此梯度需要正確反向傳播。 我能想出的所有解決方案都需要就地操作(創建一個新的張量並通過迭代num_repeats
填充它),但如果我理解正確,這將不會保留梯度(如果我錯了,請糾正我,我是新手到整個 Pytorch 的事情)。
如果您使用的是 PyTorch >= 1.1.0,您可以使用torch.repeat_interleave
。
repeat_tensor = torch.tensor(num_repeats).to(X.device, torch.int64)
X_dup = torch.repeat_interleave(X, repeat_tensor, dim=1)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.