NumPy apply function to groups of rows corresponding to another NumPy array

I have a NumPy array in which each row represents an (x, y, z) coordinate, like so:

a = array([[0, 0, 1],
           [1, 1, 2],
           [4, 5, 1],
           [4, 5, 2]])

I also have another NumPy array containing the unique values of the z-coordinates of that array, like so:

b = array([1, 2])

How can I apply a function, let's call it "f", to each group of rows in a that corresponds to a value in b? For example, the first value of b is 1, so I would get all rows of a that have a 1 in the z-coordinate. Then I apply the function to all of those rows.

In the end, the output would be an array with the same shape as b.

I'm trying to vectorize this to make it as fast as possible. Thanks!

Example of an expected output (assuming that f is count()):

c = array([2, 2])

because there are 2 rows in array a that have a z value of 1 (the first value in b), and also 2 rows in array a that have a z value of 2 (the second value in b).

A trivial solution would be to iterate over array b like so:

c = []
for val in b:
    c.append(f(a[a[:, 2] == val]))  # apply f to the rows of a whose z equals val
c = np.array(c)

My attempt:

I tried doing something like this, but it just returns an empty array:

func(a[a[:, 2]==b])
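The attempt above most likely fails because a[:, 2] (length 4) and b (length 2) cannot be compared elementwise: older NumPy versions return a single False for such a comparison (so the indexing selects nothing, hence the empty array), while recent versions raise a broadcasting error instead. Adding an axis to b makes the comparison broadcast into a 2D mask with one row per value of b. The sketch below assumes the example a and b from the question; it is vectorized, but the mask needs len(b) × len(a) booleans of memory.

```python
import numpy as np

a = np.array([[0, 0, 1],
              [1, 1, 2],
              [4, 5, 1],
              [4, 5, 2]])
b = np.array([1, 2])

# mask[i, j] is True when row j of a has z == b[i]; shape (len(b), len(a))
mask = a[:, 2] == b[:, None]

# The count() example is a row-wise sum of the mask
counts = mask.sum(axis=1)
print(counts)  # [2 2]

# Linear reductions also vectorize: (len(b), len(a)) @ (len(a), 3) gives
# per-group column sums, which are then summed into one total per group
sums = (mask.astype(a.dtype) @ a).sum(axis=1)
print(sums)  # [11 15]
```

Only reductions expressible through the mask (counts, sums, means) work this way; an arbitrary f still needs one of the grouping approaches.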

The problem is that the groups of rows with the same z can have different sizes, so you cannot stack them into one 3D NumPy array, which would make it easy to apply a function along the third dimension. One solution is to use a for-loop; another is to use np.split:

import numpy as np

a = np.array([[0, 0, 1],
              [1, 1, 2],
              [4, 5, 1],
              [4, 5, 2],
              [4, 3, 1]])


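a_sorted = a[a[:, 2].argsort(kind='stable')]  # sort rows by z; 'stable' keeps the original row order within each group

inds = np.unique(a_sorted[:, 2], return_index=True)[1]  # index of the first row of each z group

a_split = np.split(a_sorted, inds)[1:]  # inds[0] is 0, so the first chunk is empty and is dropped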

# [array([[0, 0, 1],
#         [4, 5, 1],
#         [4, 3, 1]]),

#  array([[1, 1, 2],
#         [4, 5, 2]])]

f = np.sum  # example of a function

result = list(map(f, a_split))
# [19, 15]
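When f is a NumPy ufunc reduction such as a sum, the Python-level map over the chunks can be avoided entirely. The following is a sketch using np.add.reduceat on the same sorted array; it is a swapped-in technique rather than part of the answer above, and it only applies to ufunc reductions (add, maximum, minimum, ...), not to an arbitrary f.

```python
import numpy as np

a = np.array([[0, 0, 1],
              [1, 1, 2],
              [4, 5, 1],
              [4, 5, 2],
              [4, 3, 1]])

# Sort rows by z and find where each group starts, as in the split solution
a_sorted = a[a[:, 2].argsort(kind='stable')]
z_values, starts, counts = np.unique(a_sorted[:, 2], return_index=True, return_counts=True)

# reduceat sums each slice a_sorted[starts[i]:starts[i + 1]] along axis 0,
# producing one row of column sums per group without a Python loop
col_sums = np.add.reduceat(a_sorted, starts, axis=0)
result = col_sums.sum(axis=1)  # total of all values in each group
print(z_values, result)  # [1 2] [19 15]
print(counts)  # [3 2]
```

The group sizes come for free from return_counts=True, which covers the count() case from the question.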

But IMHO the best solution is to use pandas and groupby, as suggested by FBruzzesi. You can then convert the result to a NumPy array.

EDIT: For completeness, here are the other two solutions:

List comprehension:

b = np.unique(a[:,2])
result = [f(a[a[:,2] == z]) for z in b]

Pandas:

import pandas as pd

df = pd.DataFrame(a, columns=list('XYZ'))
result = df.groupby(['Z']).apply(lambda x: f(x.values)).tolist()  # f receives each group's rows as a 2D array

This is the performance plot I got for a = np.random.randint(0, 100, (n, 3)):

[Image: performance plot]

As you can see, up to approximately n = 10^5 the "split solution" is the fastest, but beyond that the pandas solution performs better.
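To repeat a comparison like this on your own machine, a minimal timing sketch with the standard-library timeit module could look like the following. The helper names split_solution and comprehension_solution are hypothetical wrappers around the two NumPy solutions above, and the data comes from np.random.default_rng rather than np.random.randint (an assumption; both draw integers in [0, 100)). Timings depend on hardware and library versions, so no numbers are claimed here.

```python
import timeit

import numpy as np


def split_solution(a, f):
    """Hypothetical wrapper around the sort + np.split approach."""
    a_sorted = a[a[:, 2].argsort(kind='stable')]
    starts = np.unique(a_sorted[:, 2], return_index=True)[1]
    # Splitting at starts[1:] yields exactly one chunk per z value
    return [f(g) for g in np.split(a_sorted, starts[1:])]


def comprehension_solution(a, f):
    """Hypothetical wrapper around the list-comprehension approach."""
    return [f(a[a[:, 2] == z]) for z in np.unique(a[:, 2])]


rng = np.random.default_rng(0)
a = rng.integers(0, 100, size=(10**5, 3))

# Both approaches must agree before timing them
assert split_solution(a, np.sum) == comprehension_solution(a, np.sum)

for fn in (split_solution, comprehension_solution):
    seconds = timeit.timeit(lambda: fn(a, np.sum), number=5)
    print(f"{fn.__name__}: {seconds / 5:.4f} s per call")
```

Varying the size=(n, 3) argument over several n reproduces the shape of a scaling plot.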

If you are allowed to use pandas:

import pandas as pd
df=pd.DataFrame(a, columns=['x','y','z'])

df.groupby('z').agg(f)

Here f can be any custom aggregation function. Note that .agg calls f once per column within each group (on x and y here; z becomes the group index).
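To illustrate the per-column behaviour of .agg and the conversion back to NumPy, here is a small sketch; the range function (max minus min) is just a hypothetical stand-in for f.

```python
import numpy as np
import pandas as pd

a = np.array([[0, 0, 1],
              [1, 1, 2],
              [4, 5, 1],
              [4, 5, 2]])
df = pd.DataFrame(a, columns=['x', 'y', 'z'])

# Hypothetical f: .agg calls it once per column (x and y) of each z group
per_column = df.groupby('z').agg(lambda s: s.max() - s.min())
print(per_column.to_numpy())  # [[4 5]
                              #  [3 4]]

# Row counts per group, converted to a NumPy array like c in the question
c = df.groupby('z').size().to_numpy()
print(c)  # [2 2]
```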

Numeric example:

import numpy as np

a = np.array([[0, 0, 1],
              [1, 1, 2],
              [4, 5, 1],
              [4, 5, 2]])
df=pd.DataFrame(a, columns=['x','y','z'])
df.groupby('z').size()

z
1    2
2    2
dtype: int64

Note that .size is the way to count the number of rows per group.
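For comparison, the same per-group count can be obtained in pure NumPy with np.unique(..., return_counts=True). This sketch also aligns the counts with the order of b via np.searchsorted, under the assumption that every value of b actually occurs in a[:, 2].

```python
import numpy as np

a = np.array([[0, 0, 1],
              [1, 1, 2],
              [4, 5, 1],
              [4, 5, 2]])
b = np.array([1, 2])

# Sorted unique z values and how many rows carry each one
z_values, counts = np.unique(a[:, 2], return_counts=True)

# Look up each value of b among the sorted unique values
# (assumes every value of b is present in a[:, 2])
c = counts[np.searchsorted(z_values, b)]
print(c)  # [2 2]
```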

To keep it in pure NumPy, maybe this can suit your case:

tmp = np.array([a[a[:,2]==i] for i in b])
tmp 
array([[[0, 0, 1],
        [4, 5, 1]],

       [[1, 1, 2],
        [4, 5, 2]]])

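which is an array containing one sub-array per group. This only works because both groups have the same number of rows; with groups of different sizes, np.array cannot build a regular 3D array (recent NumPy versions raise a ValueError unless dtype=object is passed).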

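import numpy as np

c = np.array([])  # starts as a float array, so the counts come out as floats
for x in np.nditer(b):
    # count the rows of a whose z equals x and append that count to c
    c = np.append(c, np.where((a[:, 2] == x))[0].shape[0])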

Output:

[2. 2.]
