
Filter rows of a numpy array?

I am looking to apply a function to each row of a numpy array. If this function evaluates to True, I will keep the row; otherwise I will discard it. For example, my function might be:

def f(row):
    return sum(row) > 10

I was wondering if there is something similar to:

np.apply_over_axes()

which applies a function to each row of a numpy array and returns the result. I was hoping for something like:

np.filter_over_axes()

which would apply a function to each row of a numpy array and return only the rows for which the function returned True. Is there anything like this, or should I just use a for loop?

Ideally, you would be able to implement a vectorized version of your function and use it for boolean indexing. For the vast majority of problems, this is the right solution. NumPy provides many functions that operate along a chosen axis, as well as all the basic arithmetic operations and comparisons, so most useful conditions can be vectorized.

import numpy as np

x = np.random.randn(20, 3)
# One boolean per row: True where the row sum exceeds 0.5
x_new = x[np.sum(x, axis=1) > .5]
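With random data it is hard to see what the mask contains, so here is a small sketch on a fixed array (the values are chosen purely for illustration) using the question's threshold of 10. The second mask shows how comparisons combine: & and | act element-wise, and each comparison needs its own parentheses because these operators bind more tightly than >.

```python
import numpy as np

# A small, fixed array so the result can be checked by hand
x = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9],
              [-5, 20, 1]])

# Row sums are 6, 15, 24 and 16, so rows 1, 2 and 3 pass the test
mask = x.sum(axis=1) > 10
print(mask)
print(x[mask])

# Add a second condition: every element in the row must be positive.
# Row 3 contains -5, so only rows 1 and 2 remain.
mask2 = (x.sum(axis=1) > 10) & (x > 0).all(axis=1)
print(x[mask2])
```

If you prefer a function call over indexing syntax, np.compress(mask, x, axis=0) selects the same rows.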

If you are absolutely sure that you can't do the above, I would suggest using a list comprehension (or np.apply_along_axis) to create an array of booleans to index with.

def myfunc(row):
    return sum(row) > .5

# Calls myfunc once per row in Python, then indexes with the resulting bool array
bool_arr = np.array([myfunc(row) for row in x])
x_new = x[bool_arr]
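For completeness, the np.apply_along_axis route might look like the sketch below. filter_rows is a hypothetical helper, not a NumPy function (NumPy has no np.filter_over_axes); it builds the boolean mask with np.apply_along_axis and then indexes with it. It still calls the Python predicate once per row, so it does not approach the speed of the vectorized version.

```python
import numpy as np

def myfunc(row):
    # Per-row predicate: True if the row sum exceeds 0.5
    return row.sum() > .5

def filter_rows(arr, pred):
    """Hypothetical helper: keep the rows of a 2-D array for which pred(row) is True."""
    if arr.shape[0] == 0:
        # np.apply_along_axis raises ValueError when there are no rows to iterate over
        return arr
    # axis=1 hands each row to pred; the result is a 1-D array with one value per row
    mask = np.apply_along_axis(pred, 1, arr).astype(bool)
    return arr[mask]

rng = np.random.default_rng(0)
x = rng.standard_normal((20, 3))
x_new = filter_rows(x, myfunc)
```

The empty-input guard matters because np.apply_along_axis cannot infer an output shape from an array with zero rows.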

This gets the job done in a relatively clean way, but it will be significantly slower than a vectorized version. An example, timed in IPython with %timeit:

x = np.random.randn(5000, 200)

%timeit x[np.sum(x, axis=1) > .5]
# 100 loops, best of 3: 5.71 ms per loop

%timeit x[np.array([myfunc(row) for row in x])]
# 1 loops, best of 3: 217 ms per loop
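%timeit is an IPython magic command. Outside IPython, a comparable measurement can be sketched with the standard-library timeit module. The array shape and threshold mirror the example above; the repeat and number settings are arbitrary choices. Timings vary with hardware and NumPy version, so no expected numbers are given.

```python
import timeit

import numpy as np

def myfunc(row):
    return sum(row) > .5

x = np.random.randn(5000, 200)

def vectorized():
    # One C-level reduction over all rows, then a single boolean index
    return x[np.sum(x, axis=1) > .5]

def per_row():
    # One Python-level call to myfunc per row
    return x[np.array([myfunc(row) for row in x])]

timings = {}
for name, fn in [("vectorized", vectorized), ("per-row", per_row)]:
    # timeit.repeat returns the total time of `number` calls for each of `repeat` runs;
    # the minimum run divided by `number` is the least noisy per-call estimate
    timings[name] = min(timeit.repeat(fn, number=3, repeat=3)) / 3
    print(f"{name}: {timings[name] * 1e3:.2f} ms per call")
```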

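Statement: The technical posts on this site are published under the CC BY-SA 4.0 license. If you repost them, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.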

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM