簡體   English   中英

過濾numpy數組的行?

[英]Filter rows of a numpy array?

我希望將函數應用於numpy數組的每一行。 如果此函數的計算結果為true,我將保留該行,否則我將丟棄該行。 例如,我的功能可能是:

def f(row):
    if sum(row)>10: return True
    else: return False

我想知道是否有類似的東西:

np.apply_over_axes()

它將函數應用於numpy數組的每一行並返回結果。 我希望有類似的東西:

np.filter_over_axes()

這會將函數應用於numpy數組的每一行,並且只返回函數返回true的行。 有這樣的事嗎? 或者我應該只使用for循環?

理想情況下,您將能夠實現函數的矢量化版本並使用它來執行布爾索引 對於絕大多數問題,這是正確的解決方案。 Numpy提供了許多可以在各種軸上執行的函數以及所有基本操作和比較,因此大多數有用的條件都應該是可矢量化的。

import numpy as np

x = np.random.randn(20, 3)
x_new = x[np.sum(x, axis=1) > .5]

如果您完全確定不能執行上述操作,我建議使用list comprehension(或np.apply_along_axis )來創建一個np.apply_along_axis數組來索引。

def myfunc(row):
    return sum(row) > .5

bool_arr = np.array([myfunc(row) for row in x])
x_new = x[bool_arr]

這將以相對干凈的方式完成工作,但將比矢量化版本慢得多。 一個例子:

x = np.random.randn(5000, 200)

%timeit x[np.sum(x, axis=1) > .5]
# 100 loops, best of 3: 5.71 ms per loop

%timeit x[np.array([myfunc(row) for row in x])]
# 1 loops, best of 3: 217 ms per loop

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM