[英]Indices of multiple elements in a numpy array
我有一个numpy数组和如下列表
y=np.array([[1],[2],[1],[3],[1],[3],[2],[2]])
x=[1,2,3]
我想返回一个数组的元组,每个数组包含yie中x的每个元素的索引
(array([[0,2,4]]),array([[1,6,7]]),array([[3,5]]))
是否可以矢量化方式(没有任何循环)完成此操作?
请尝试以下操作:
y = y.flatten()
[np.where(y == searchval)[0] for searchval in x]
一种解决方案是map
y = y.reshape(1,len(y))
map(lambda k: np.where(y==k)[-1], x)
[array([0, 2, 4]),
array([1, 6, 7]),
array([3, 5])]
性能合理。 对于100000行,
%timeit list(map(lambda k: np.where(y==k), x))
3.1 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
对于这个小例子,字典方法实际上更快(然后是wheres):
dd = {i:[] for i in [1,2,3]}
for i,v in enumerate(y):
v=v[0]
if v in dd:
dd[v].append(i)
list(dd.values())
其他SO问题中也出现了这个问题。 已经提出了使用unique
和sort
替代方案,但是它们更加复杂且难以重新创建-不一定更快。
对于numpy
这不是一个理想的问题。 结果是数组列表或大小不同的列表,这是一个很好的线索,表明不可能使用简单的“矢量化”全数组解决方案。 如果速度足够重要,则可能需要查看numba
或cython
实现。
根据值的混合,不同的方法可能具有不同的相对时间。 唯一值很少,但是长子列表可能更喜欢使用重复where
方法。 带有短子列表的许多唯一值可能会喜欢在y
上迭代的方法。
您可以使用collections.defaultdict
后跟一个理解:
y = np.array([[1],[2],[1],[3],[1],[3],[2],[2]])
x = [1,2,3]
from collections import defaultdict
d = defaultdict(list)
for idx, item in enumerate(y.flat):
d[item].append(idx)
res = tuple(np.array(d[k]) for k in x)
(array([0, 2, 4]), array([1, 6, 7]), array([3, 5]))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.