簡體   English   中英

基於參考 n-dim 數組對 n-dim 數組進行操作的最有效方法

[英]Most efficient way to operate on a n-dim array based on a reference n-dim array

我有兩個相同形狀的 numpy arrays : dat_araref_ara

我想在op_funcaxis = -1上執行dat_ara操作,但是我只想對每個數組中的選定值切片進行操作,當參考數組thres超過閾值 thres 時指定ref_ara

為了說明,在 arrays 只是 2-dim 的簡單情況下,我有:

thres = 4

op_func = np.average

ref_ara = array([[1, 2, 1, 4, 3, 5, 1, 5, 2, 5],
                 [1, 2, 2, 1, 1, 1, 2, 7, 5, 8],
                 [2, 3, 2, 5, 1, 6, 5, 2, 7, 3]]) 

dat_ara = array([[1, 0, 0, 1, 1, 1, 1, 0, 1, 1],
                 [1, 1, 1, 1, 1, 1, 1, 0, 1, 0],
                 [1, 0, 1, 1, 1, 1, 0, 1, 1, 1]]) 

我們看到 thres 在thresaxis=0中的第一個、第二個和第三個數組的第 5 個、第 7 個和第 3 個索引中被ref_ara 因此我想要的結果是

out_ara = array([op_func(array([1, 0, 0, 1, 1, 1]), 
                 op_func(array([1, 1, 1, 1, 1, 1, 1, 0]),
                 op_func(array([1, 0, 1, 1])])

這個問題很困難,因為它需要引用ref_ara 如果不是這樣,我可以簡單地使用numpy.apply_along_axis

我嘗試擴展兩個 arrays 的尺寸以將它們關聯起來進行計算,即:

assos_ara = np.append(np.expand_dims(dat_ara, axis=-1), np.expand_dims(ref_ara, axis=-1), axis=-1)

But again, numpy.apply_along_axis requires the input function to only operate on 1-dim arrays, and thus I still cannot utilise the function.

我知道的唯一其他方法是明智地遍歷 arrays 索引,但是,arrays 具有兩個 arrays 的不斷變化的尺寸,這不是計算效率高的問題,而是一個棘手的問題。

我很想使用矢量化函數來幫助這個過程。 go 最有效的方法是什么?

這是屏蔽 arrays 的一個很好的用例,因為它們允許您對部分數據執行正常的 numpy 操作。

假設每一行至少包含一個大於閾值的值。 您可以將斷點的索引計算為

breaks = np.argmax(ref_ara > thres, axis=-1)   # 5, 7, 3

然后,您可以使用我之前鏈接的問題答案創建一個掩碼。 掩碼通常是處理 numpy 中不規則形狀數據的最佳方式。

mask = np.arange(ref_ara.shape[-1]) <= breaks.reshape(*breaks.shape, 1)

在這里,我們不需要對arange做任何花哨的事情,因為它位於最后一個維度。 如果不是這種情況,您可能希望在范圍為 go 的中斷形狀中插入一個 1,並用 1 填充范圍形狀的尾部。

現在掩碼數組和 ufunc 解決方案略有不同。 掩碼數組版本更通用,所以它首先出現:

data = np.ma.array(data_ara, mask=~mask)

掩碼的 arrays 從正常的 boolean 索引的方式向后解釋掩碼,因此我們反轉掩碼。 或者,您可以使用>而不是<=來計算掩碼。 計算現在很簡單:

out_ara = np.ma.average(data, axis=-1).data

一個不太通用的替代方法是將您的操作分解為 ufunc,並使用它們提供的掩碼。 這對於np.average來說很容易,它只是np.sumnp.divide ,但對於更復雜的操作可能更難。

從 numpy 1.17.0 開始, np.sum有一個where關鍵字:

out_ara = np.sum(dat_ara, where=mask, axis=-1) / breaks

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM