![](/img/trans.png)
[英]How to randomly set k elements in a dimension in pytorch tensor to a specific value?
[英]How to set minimum k elements in a dimension in pytorch tensor to a specific value?
例如,如果我有張量(形狀 [2, 3, 5])
[[[0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
[0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
[0.9346, 0.5936, 0.8694, 0.5677, 0.7411]],
[[0.4294, 0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317, 0.1053],
[0.2695, 0.3588, 0.1994, 0.5472, 0.0062]]]
並且 k = 2,我想將維度中的最小 k 個元素(例如,dim=2)設置為特定值(例如,5):
[[[0.8823, 0.9150, 5, 0.9593, 5],
[0.6009, 5, 0.7936, 0.9408, 5],
[0.9346, 5, 0.8694, 5, 0.7411]],
[[5, 0.8854, 0.5739, 5, 0.6274],
[5, 0.4414, 0.2969, 0.8317, 5],
[0.2695, 0.3588, 5, 0.5472, 5]]]
您可以提取每行k
最低的元素並用該張量掩蓋初始張量。 給定k=2
和v0=5
(替換k
最低元素的值):
>>> v, _ = x.sort(dim=2)
>>> v[:,:,k:k+1]
tensor([[[0.8823],
[0.6009],
[0.7411]],
[[0.5739],
[0.2969],
[0.2695]]])
執行切片x[:,:,k:k+1]
而不是使用x[:,:,k]
進行標准索引可以保持維數不變。
然后我們可以應用torch.where
:
>>> torch.where(x < v[:,:,k:k+1], v0, x)
tensor([[[0.8823, 0.9150, 5.0000, 0.9593, 5.0000],
[0.6009, 5.0000, 0.7936, 0.9408, 5.0000],
[0.9346, 5.0000, 0.8694, 5.0000, 0.7411]],
[[5.0000, 0.8854, 0.5739, 5.0000, 0.6274],
[5.0000, 0.4414, 0.2969, 0.8317, 5.0000],
[0.2695, 0.3588, 5.0000, 0.5472, 5.0000]]])
或者,您可以直接在原地重新分配 masked- x
上的值:
>>> x[x < v[:,:,k:k+1]] = v0
您可以使用torch.topk
和torch.Tensor.scatter_
的組合。
(因為torch.topk
返回max_top_k
而你想要min_top_k
。我們可以使用-1*all_num
來獲取min_top_k
)
val, ind = torch.topk(-a, k=2)
a.scatter_(index=ind, dim=-1, value=5)
print(a)
tensor([[[0.8823, 0.9150, 5.0000, 0.9593, 5.0000],
[0.6009, 5.0000, 0.7936, 0.9408, 5.0000],
[0.9346, 5.0000, 0.8694, 5.0000, 0.7411]],
[[5.0000, 0.8854, 0.5739, 5.0000, 0.6274],
[5.0000, 0.4414, 0.2969, 0.8317, 5.0000],
[0.2695, 0.3588, 5.0000, 0.5472, 5.0000]]])
輸入:
>>> a = torch.tensor([[[0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
[0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
[0.9346, 0.5936, 0.8694, 0.5677, 0.7411]],
[[0.4294, 0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317, 0.1053],
[0.2695, 0.3588, 0.1994, 0.5472, 0.0062]]])
>>> torch.topk(-a, k=2)
# values=tensor(
# [[[-0.3829, -0.3904],
# [-0.1332, -0.2566],
# [-0.5677, -0.5936]],
# [[-0.2666, -0.4294],
# [-0.1053, -0.2696],
# [-0.0062, -0.1994]]]),
# indices=tensor(
# [[[2, 4],
# [4, 1],
# [3, 1]],
# [[3, 0],
# [4, 0],
# [4, 2]]])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.