簡體   English   中英

如何將 pytorch 張量中維度中的最小 k 個元素設置為特定值?

[英]How to set minimum k elements in a dimension in pytorch tensor to a specific value?

例如,如果我有張量(形狀 [2, 3, 5])

[[[0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
  [0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
  [0.9346, 0.5936, 0.8694, 0.5677, 0.7411]],

 [[0.4294, 0.8854, 0.5739, 0.2666, 0.6274],
  [0.2696, 0.4414, 0.2969, 0.8317, 0.1053],
  [0.2695, 0.3588, 0.1994, 0.5472, 0.0062]]]

並且 k = 2,我想將維度中的最小 k 個元素(例如,dim=2)設置為特定值(例如,5):

[[[0.8823, 0.9150, 5, 0.9593, 5],
  [0.6009, 5, 0.7936, 0.9408, 5],
  [0.9346, 5, 0.8694, 5, 0.7411]],

 [[5, 0.8854, 0.5739, 5, 0.6274],
  [5, 0.4414, 0.2969, 0.8317, 5],
  [0.2695, 0.3588, 5, 0.5472, 5]]]

您可以提取k最低的元素並用該張量掩蓋初始張量。 給定k=2v0=5 (替換k最低元素的值):

>>> v, _ = x.sort(dim=2)

>>> v[:,:,k:k+1]
tensor([[[0.8823],
         [0.6009],
         [0.7411]],

        [[0.5739],
         [0.2969],
         [0.2695]]])

執行切片x[:,:,k:k+1]而不是使用x[:,:,k]進行標准索引可以保持維數不變。

然后我們可以應用torch.where

>>> torch.where(x < v[:,:,k:k+1], v0, x)
tensor([[[0.8823, 0.9150, 5.0000, 0.9593, 5.0000],
         [0.6009, 5.0000, 0.7936, 0.9408, 5.0000],
         [0.9346, 5.0000, 0.8694, 5.0000, 0.7411]],

        [[5.0000, 0.8854, 0.5739, 5.0000, 0.6274],
         [5.0000, 0.4414, 0.2969, 0.8317, 5.0000],
         [0.2695, 0.3588, 5.0000, 0.5472, 5.0000]]])

或者,您可以直接在原地重新分配 masked- x上的值:

>>> x[x < v[:,:,k:k+1]] = v0

您可以使用torch.topktorch.Tensor.scatter_的組合。

(因為torch.topk返回max_top_k而你想要min_top_k 。我們可以使用-1*all_num來獲取min_top_k

val, ind = torch.topk(-a, k=2)
a.scatter_(index=ind, dim=-1, value=5)
print(a)

tensor([[[0.8823, 0.9150, 5.0000, 0.9593, 5.0000],
         [0.6009, 5.0000, 0.7936, 0.9408, 5.0000],
         [0.9346, 5.0000, 0.8694, 5.0000, 0.7411]],

        [[5.0000, 0.8854, 0.5739, 5.0000, 0.6274],
         [5.0000, 0.4414, 0.2969, 0.8317, 5.0000],
         [0.2695, 0.3588, 5.0000, 0.5472, 5.0000]]])

輸入:

>>> a = torch.tensor([[[0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
                      [0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
                      [0.9346, 0.5936, 0.8694, 0.5677, 0.7411]],
                  
                     [[0.4294, 0.8854, 0.5739, 0.2666, 0.6274],
                      [0.2696, 0.4414, 0.2969, 0.8317, 0.1053],
                      [0.2695, 0.3588, 0.1994, 0.5472, 0.0062]]])


>>> torch.topk(-a, k=2)
# values=tensor(
#     [[[-0.3829, -0.3904],
#       [-0.1332, -0.2566],
#       [-0.5677, -0.5936]],

#      [[-0.2666, -0.4294],
#       [-0.1053, -0.2696],
#       [-0.0062, -0.1994]]]),

# indices=tensor(
#     [[[2, 4],
#       [4, 1],
#       [3, 1]],

#      [[3, 0],
#       [4, 0],
#       [4, 2]]])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM