簡體   English   中英

Pytorch Tensor 如何獲取特定值的索引

[英]How Pytorch Tensor get the index of specific value

使用 python 列表,我們可以:

a = [1, 2, 3]
assert a.index(2) == 1

pytorch 張量如何直接找到.index()

我認為沒有從list.index()到 pytorch 函數的直接轉換。 但是,您可以使用tensor==number然后使用nonzero()函數獲得類似的結果。 例如:

t = torch.Tensor([1, 2, 3])
print ((t == 2).nonzero(as_tuple=True)[0])

這段代碼返回

1

[大小為1x1的torch.LongTensor]

對於多維張量,您可以執行以下操作:

(tensor == target_value).nonzero(as_tuple=True)

生成的張量的形狀為number_of_matches x tensor_dimension 例如,假設tensor是一個3 x 4張量(這意味着維度是 2),結果將是一個二維張量,其中包含行中匹配項的索引。

tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
(tensor == 2).nonzero(as_tuple=False)
>>> tensor([[0, 1],
        [0, 2],
        [1, 2]])
x = torch.Tensor([11, 22, 33, 22])
print((x==22).nonzero().squeeze())

張量([1, 3])

可以通過轉換為 numpy 來完成,如下所示

import torch
x = torch.range(1,4)
print(x)
===> tensor([ 1.,  2.,  3.,  4.]) 
nx = x.numpy()
np.where(nx == 3)[0][0]
===> 2

根據其他人的回答:

t = torch.Tensor([1, 2, 3])
print((t==1).nonzero().item())

已經給出的答案很好,但是當我在沒有匹配的情況下嘗試時,它們無法處理。 為此,請參見:

def index(tensor: Tensor, value, ith_match:int =0) -> Tensor:
    """
    Returns generalized index (i.e. location/coordinate) of the first occurence of value
    in Tensor. For flat tensors (i.e. arrays/lists) it returns the indices of the occurrences
    of the value you are looking for. Otherwise, it returns the "index" as a coordinate.
    If there are multiple occurences then you need to choose which one you want with ith_index.
    e.g. ith_index=0 gives first occurence.

    Reference: https://stackoverflow.com/a/67175757/1601580
    :return:
    """
    # bool tensor of where value occurred
    places_where_value_occurs = (tensor == value)
    # get matches as a "coordinate list" where occurence happened
    matches = (tensor == value).nonzero()  # [number_of_matches, tensor_dimension]
    if matches.size(0) == 0:  # no matches
        return -1
    else:
        # get index/coordinate of the occurence you want (e.g. 1st occurence ith_match=0)
        index = matches[ith_match]
        return index

歸功於這個偉大的答案: https : //stackoverflow.com/a/67175757/1601580

在我看來,調用tolist()簡單易懂。

t = torch.Tensor([1, 2, 3])
t.tolist().index(2) # -> 1

用於在 1d 張量/數組示例中查找元素的索引

mat=torch.tensor([1,8,5,3])

找到 5 的索引

five=5

numb_of_col=4
for o in range(numb_of_col):
   if mat[o]==five:
     print(torch.tensor([o]))

要找到 2d/3d 張量的元素索引,請將其轉換為 1d #ie example.view(number of elements)

例子

mat=torch.tensor([[1,2],[4,3])
#to find index of 2

five = 2
mat=mat.view(4)
numb_of_col = 4
for o in range(numb_of_col):
   if mat[o] == five:
     print(torch.tensor([o]))    

對於浮點張量,我使用它來獲取張量中元素的索引。

print((torch.abs((torch.max(your_tensor).item()-your_tensor))<0.0001).nonzero())

這里我想獲取浮點張量中max_value的索引,你也可以像這樣放置你的值來獲取張量中任何元素的索引。

print((torch.abs((YOUR_VALUE-your_tensor))<0.0001).nonzero())
    import torch
    x_data = variable(torch.Tensor([[1.0], [2.0], [3.0]]))
    print(x_data.data[0])
    >>tensor([1.])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM