[英]output of numpy.where(condition) is not an array, but a tuple of arrays: why?
我正在嘗試numpy.where(condition[, x, y])
函數。
從numpy文檔中,我了解到,如果僅給出一個數組作為輸入,它應該返回該數組非零的索引(即“ True”):
如果僅給出條件,則返回元組condition.nonzero(),其中condition為True的索引。
但是,如果嘗試一下,它將返回一個包含兩個元素的元組 ,其中第一個是所需的索引列表,第二個是空元素:
>>> import numpy as np
>>> array = np.array([1,2,3,4,5,6,7,8,9])
>>> np.where(array>4)
(array([4, 5, 6, 7, 8]),) # notice the comma before the last parenthesis
所以問題是:為什么? 這種行為的目的是什么? 在什么情況下這很有用? 確實,要獲得所需的索引列表,我必須添加索引,如np.where(array>4)[0]
,這似乎是“丑陋的”。
附錄
我從一些答案中了解到,它實際上只是一個元素的元組。 仍然我不明白為什么要這樣輸出。 為了說明這是不理想的,請考慮以下錯誤(首先引起了我的問題):
>>> import numpy as np
>>> array = np.array([1,2,3,4,5,6,7,8,9])
>>> pippo = np.where(array>4)
>>> pippo + 1
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: can only concatenate tuple (not "int") to tuple
因此您需要做一些索引來訪問實際的索引數組:
>>> pippo[0] + 1
array([5, 6, 7, 8, 9])
在Python中, (1)
僅表示1
。 可以隨意將()
添加到組號和表達式中,以提高可讀性(例如(1+3)*3
v (1+3,)*3
)。 因此,要表示一個1元素元組,它使用(1,)
(並要求您也使用它)。
從而
(array([4, 5, 6, 7, 8]),)
是一個元素元組,該元素是一個數組。
如果將where
應用於2d數組,結果將是2元素元組。
where
的結果是可以將其直接插入索引槽中,例如
a[where(a>0)]
a[a>0]
應該返回相同的東西
就像
I,J = where(a>0) # a is 2d
a[I,J]
a[(I,J)]
或舉個例子:
In [278]: a=np.array([1,2,3,4,5,6,7,8,9])
In [279]: np.where(a>4)
Out[279]: (array([4, 5, 6, 7, 8], dtype=int32),) # tuple
In [280]: a[np.where(a>4)]
Out[280]: array([5, 6, 7, 8, 9])
In [281]: I=np.where(a>4)
In [282]: I
Out[282]: (array([4, 5, 6, 7, 8], dtype=int32),)
In [283]: a[I]
Out[283]: array([5, 6, 7, 8, 9])
In [286]: i, = np.where(a>4) # note the , on LHS
In [287]: i
Out[287]: array([4, 5, 6, 7, 8], dtype=int32) # not tuple
In [288]: a[i]
Out[288]: array([5, 6, 7, 8, 9])
In [289]: a[(i,)]
Out[289]: array([5, 6, 7, 8, 9])
======================
無論輸入數組的尺寸如何, np.flatnonzero
顯示了僅返回一個數組的正確方法。
In [299]: np.flatnonzero(a>4)
Out[299]: array([4, 5, 6, 7, 8], dtype=int32)
In [300]: np.flatnonzero(a>4)+10
Out[300]: array([14, 15, 16, 17, 18], dtype=int32)
它的文檔說:
這等效於a.ravel()。nonzero()[0]
實際上,這實際上就是函數的功能。
通過展平a
消除了如何處理多個尺寸的問題。 然后,它將響應從元組中刪除,從而為您提供一個簡單的數組。 通過展平,它對於一維數組並沒有特殊情況。
==========================
@Divakar建議np.argwhere
:
In [303]: np.argwhere(a>4)
Out[303]:
array([[4],
[5],
[6],
[7],
[8]], dtype=int32)
哪個做np.transpose(np.where(a>4))
或者,如果您不喜歡列向量,則可以再次轉置它
In [307]: np.argwhere(a>4).T
Out[307]: array([[4, 5, 6, 7, 8]], dtype=int32)
除了現在是1xn陣列。
我們也可以同樣都包裹where
在array
:
In [311]: np.array(np.where(a>4))
Out[311]: array([[4, 5, 6, 7, 8]], dtype=int32)
有多種方法可以將數組從where
元組中取出( [0]
, i,=
, transpose
, array
等)。
簡短的答案: np.where
設計為具有一致的輸出,而與數組的尺寸無關。
二維數組有兩個索引,因此np.where
的結果是一個包含相關索引的長度為2的元組。 這可以概括為3維的長度為3的元組,4維的長度為4的元組或N維的長度為N的元組。 通過此規則,很明顯在1維中,結果應為長度為1的元組。
只需使用np.asarray
函數。 在您的情況下:
>>> import numpy as np
>>> array = np.array([1,2,3,4,5,6,7,8,9])
>>> pippo = np.asarray(np.where(array>4))
>>> pippo + 1
array([[5, 6, 7, 8, 9]])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.