[英]numpy multidimensional indexing and the function 'take'
在一周中的奇數天,我幾乎理解numpy中的多維索引。 Numpy有一個函數'take'似乎可以做我想要的東西,但有額外的好處,我可以控制如果索引超出范圍會發生什么具體來說,我有一個三維數組要求作為查找表
lut = np.ones([13,13,13],np.bool)
以及一個2x2的3長向量數組,作為表中的索引
arr = np.arange(12).reshape([2,2,3]) % 13
IIUC,如果我要寫lut[arr]
那么arr
被視為2x2x3數字數組,當它們被用作lut
索引時,它們每個都返回一個13x13數組。 這解釋了為什么lut[arr].shape is (2, 2, 3, 13, 13)
。
我可以通過寫作讓它做我想做的事
lut[ arr[:,:,0],arr[:,:,1],arr[:,:,2] ] #(is there a better way to write this?)
現在這三個術語就好像它們被壓縮以生成一個2x2的元組數組而lut[<tuple>]
從lut
生成一個元素。 最終結果是來自lut
的2x2條目數組,正如我想要的那樣。
我已經閱讀了'take'功能的文檔......
這個函數與“花式”索引(使用數組索引數組)的功能相同; 但是,如果您需要沿給定軸的元素,則可以更容易使用。
和
axis:int,可選
用於選擇值的軸。
也許是天真的,我認為設置axis=2
我會得到三個值用作3元組來執行查找但實際上
np.take(lut,arr).shape = (2, 2, 3)
np.take(lut,arr,axis=0).shape = (2, 2, 3, 13, 13)
np.take(lut,arr,axis=1).shape = (13, 2, 2, 3, 13)
np.take(lut,arr,axis=2).shape = (13, 13, 2, 2, 3)
所以我很清楚我不明白發生了什么。 任何人都可以告訴我如何實現我想要的東西嗎?
我們可以計算線性指數,然后使用np.take
-
np.take(lut, np.ravel_multi_index(arr.T, lut.shape)).T
如果您對替代方案持開放態度,我們可以將indices數組重新整形為2D
,轉換為元組,使用它索引到數據數組中,為我們提供1D
,可以將其重新形成為2D
-
lut[tuple(arr.reshape(-1,arr.shape[-1]).T)].reshape(arr.shape[:2])
樣品運行 -
In [49]: lut = np.random.randint(11,99,(13,13,13))
In [50]: arr = np.arange(12).reshape([2,2,3])
In [51]: lut[ arr[:,:,0],arr[:,:,1],arr[:,:,2] ] # Original approach
Out[51]:
array([[41, 21],
[94, 22]])
In [52]: np.take(lut, np.ravel_multi_index(arr.T, lut.shape)).T
Out[52]:
array([[41, 21],
[94, 22]])
In [53]: lut[tuple(arr.reshape(-1,arr.shape[-1]).T)].reshape(arr.shape[:2])
Out[53]:
array([[41, 21],
[94, 22]])
我們可以避免np.take
方法的雙重轉置,就像這樣 -
In [55]: np.take(lut, np.ravel_multi_index(arr.transpose(2,0,1), lut.shape))
Out[55]:
array([[41, 21],
[94, 22]])
推廣到通用維度的多維數組
這可以推廣到通用號的ndarray。 昏暗的,像這樣 -
np.take(lut, np.ravel_multi_index(np.rollaxis(arr,-1,0), lut.shape))
基於tuple-based
方法應該沒有任何改變。
這是一個相同的樣本運行 -
In [95]: lut = np.random.randint(11,99,(13,13,13,13))
In [96]: arr = np.random.randint(0,13,(2,3,4,4))
In [97]: lut[ arr[:,:,:,0] , arr[:,:,:,1],arr[:,:,:,2],arr[:,:,:,3] ]
Out[97]:
array([[[95, 11, 40, 75],
[38, 82, 11, 38],
[30, 53, 69, 21]],
[[61, 74, 33, 94],
[90, 35, 89, 72],
[52, 64, 85, 22]]])
In [98]: np.take(lut, np.ravel_multi_index(np.rollaxis(arr,-1,0), lut.shape))
Out[98]:
array([[[95, 11, 40, 75],
[38, 82, 11, 38],
[30, 53, 69, 21]],
[[61, 74, 33, 94],
[90, 35, 89, 72],
[52, 64, 85, 22]]])
最初的問題是嘗試在表中進行查找,但是有些索引超出范圍,我想在發生這種情況時控制行為。
import numpy as np
lut = np.ones((5,7,11),np.int) # a 3-dimensional lookup table
print("lut.shape = ",lut.shape ) # (5,7,11)
# valid points are in the interior with value 99,
# invalid points are on the faces with value 0
lut[:,:,:] = 0
lut[1:-1,1:-1,1:-1] = 99
# set up an array of indexes with many of them too large or too small
start = -35
arr = np.arange(start,2*11*3+start,1).reshape(2,11,3)
# This solution has the advantage that I can understand what is going on
# and so I can amend it if I need to
# split arr into tuples along axis=2
arrchannels = arr[:,:,0],arr[:,:,1],arr[:,:,2]
# convert into a flat array but clip the values
ravelledarr = np.ravel_multi_index(arrchannels, lut.shape, mode='clip')
# and now turn back into a list of numpy arrays
# (not an array of the original shape )
clippedarr = np.unravel_index( ravelledarr, lut.shape)
print(clippedarr[0].shape,"*",len(clippedarr)) # produces (2, 11) * 3
# and now I can do the lookup with the indexes clipped to fit
print(lut[clippedarr])
# these are more succinct but opaque ways of doing the same
# due to @Divakar and @hjpauli respectively
print( np.take(lut, np.ravel_multi_index(arr.T, lut.shape, mode='clip')).T )
print( lut.flat[np.ravel_multi_index(arr.T, lut.shape, mode='clip')].T )
實際的應用是我有一個rgb圖像,其中包含一些帶有一些標記的紋理木材,我已經確定了它的一塊。 我想獲取此補丁中的像素集,並在整個圖像中標記與其中一個匹配的所有點。 256x256x256存在表太大,所以我在補丁的像素上運行聚類算法,並為每個聚類設置存在表(補丁中的顏色通過rgb-或hsv-space形成細長線程,因此聚類周圍的框很小)。
我使存在表略大於需要,並用False填充每個面。
一旦我設置了這些小存在表,我現在可以通過查找表中的每個像素來測試圖像的其余部分以匹配補丁,並使用剪輯來制作通常不會映射到表中的像素實際映射到表的一個面(並獲得值'False')
我沒有嘗試三維。 但是在二維中我使用numpy.take獲得了我想要的結果:
np.take(np.take(T,ix,axis=0), iy,axis=1 )
也許您可以將其擴展為3維。
作為一個例子,我可以使用兩個1-dim數組作為索引ix和iy二維模板用於離散拉普拉斯方程,
ΔT = T[ix-1,iy] + T[ix+1, iy] + T[ix,iy-1] + T[ix,iy+1] - 4*T[ix,iy]
介紹更精簡的寫作:
def q(Φ,kx,ky):
return np.take(np.take(Φ,kx,axis=0), ky,axis=1 )
然后我可以用numpy.take運行以下python代碼 :
nx = 6; ny= 10
T = np.arange(nx*ny).reshape(nx, ny)
ix = np.linspace(1,nx-2,nx-2,dtype=int)
iy = np.linspace(1,ny-2,ny-2,dtype=int)
ΔT = q(T,ix-1,iy) + q(T,ix+1,iy) + q(T,ix,iy-1) + q(T,ix,iy+1) - 4.0 * q(T,ix,iy)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.