[英]Using numpy.where() or similar to get specifc values from a row in a matrix
As an MWE, I have a 2d numpy array: 作为MWE,我有一个2d的numpy数组:
import numpy as np
n1 = np.array([[6,7,8,1], [5,2,4,8], [3,4,2,1], [8,7,2,10]])
n1
array([[ 6, 7, 8, 1],
[ 5, 2, 4, 8],
[ 3, 4, 2, 1],
[ 8, 7, 2, 10]])
I want to get the indices where '6' and '8' occur in the first row; 我想获取第一行中出现“ 6”和“ 8”的索引; where '2' and '8' appear in the second row;
第二行中出现“ 2”和“ 8”; where '3' and '4' appear in the third row, and where '7' and '2' appear in the last row.
其中“ 3”和“ 4”出现在第三行中,而“ 7”和“ 2”出现在最后一行中。 That is, I have a list of numpy arrays:
也就是说,我有一个numpy数组的列表:
list1 = [np.array([6, 8]), np.array([2, 8]), np.array([3, 4]), np.array([7, 2])]
I want [n1[i, np.where(ar1)] for i in range(len(list1))]
or something like that to return a new list with the column where the value from list1 occurs: 我想要
[n1[i, np.where(ar1)] for i in range(len(list1))]
或类似的东西返回一个新列表,其中包含出现list1的值的列:
returned_list = [np.array([0, 2]), np.array([1, 3]), np.array([0, 1]), np.array([1, 2])]
Clearly I've tried [n1[i, np.where(ar1)] for i in range(len(list1))]
. 显然,我已经尝试过
[n1[i, np.where(ar1)] for i in range(len(list1))]
。 Any ideas? 有任何想法吗?
Here's one using NumPy broadcasting
for array output - 这是使用
NumPy broadcasting
进行数组输出的一种-
In [117]: arr_list1 = np.array(list1)
In [118]: mask = (n1[:,:,None] == arr_list1[:,None,:]).any(2)
In [119]: np.where(mask)[1].reshape(-1,2)
Out[119]:
array([[0, 2],
[1, 3],
[0, 1],
[1, 2]])
Explanation 说明
Basically we are extending both n1
and arr_list1
to 3D
arrays by introducing singleton dims/dims with lengths 1
with None/np.newaxis
, such that when compared against each other would result in elementwise comparison against the last axes from those two arrays as a full 3D
array. 基本上,我们通过引入长度为
1
的None/np.newaxis
单身arr_list1
/ arr_list1
来将n1
和arr_list1
到3D
数组,这样当彼此比较时,将与这两个数组中的最后一个轴进行元素比较3D
阵列。
We then look for ANY
match along the last axis which has a length of 2
corresponding to the two elems in arr_list1
for each row. 然后,我们沿着最后一个轴查找
ANY
匹配项,该匹配项的长度为2
对应于每一行的arr_list1
的两个arr_list1
。 That gives us a 2D array. 这给了我们2D阵列。 Finally, we need the matching row indices, so
np.where()[1]
. 最后,我们需要匹配的行索引,所以
np.where()[1]
。
Step-by-step run for a closer look - 分步进行仔细观察-
1) Inputs : 1)输入:
In [124]: n1
Out[124]:
array([[ 6, 7, 8, 1],
[ 5, 2, 4, 8],
[ 3, 4, 2, 1],
[ 8, 7, 2, 10]])
In [125]: arr_list1
Out[125]:
array([[6, 8],
[2, 8],
[3, 4],
[7, 2]])
2) Comparison : 2)比较:
In [126]: (n1[:,:,None] == arr_list1[:,None,:])
Out[126]:
array([[[ True, False],
[False, False],
[False, True],
[False, False]],
[[False, False],
[ True, False],
[False, False],
[False, True]],
[[ True, False],
[False, True],
[False, False],
[False, False]],
[[False, False],
[ True, False],
[False, True],
[False, False]]], dtype=bool)
3) ANY
reduction : 3)
ANY
减少:
In [127]: (n1[:,:,None] == arr_list1[:,None,:]).any(2)
Out[127]:
array([[ True, False, True, False],
[False, True, False, True],
[ True, True, False, False],
[False, True, True, False]], dtype=bool)
In [128]: mask = _
4) Finally get matching row indices : 4)最后得到匹配的行索引:
In [130]: np.where(mask)[1].reshape(-1,2)
Out[130]:
array([[0, 2],
[1, 3],
[0, 1],
[1, 2]])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.