
Comparing a 1D array of integers to rows in a 2D array

I have a piece of code that works, but it is currently a bottleneck, and I was wondering whether there is a smarter way to do it.

I have a 1D array of integers between 0 and 20, with a length of 20 to 1000 (it varies), and I'm trying to compare it to a set of 1D arrays that are stored as rows of a 2D array. I wish to find any row in the 2D array that completely matches the 1D array.

My current approach is the following:

res = np.mean(one_d_array == two_d_array,axis=1) == 1
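As a small illustration of what this line computes (the array values here are made up purely for the example):

```python
import numpy as np

# Hypothetical toy data; the real arrays are much larger.
one_d_array = np.array([1, 2, 3])
two_d_array = np.array([[1, 2, 3],
                        [1, 0, 3],
                        [3, 2, 1]])

# Broadcasting compares one_d_array against every row,
# giving a boolean matrix with the same shape as two_d_array.
matches = one_d_array == two_d_array

# A row's mean equals 1 only when every element in that row matched.
res = np.mean(matches, axis=1) == 1
print(res)  # [ True False False]
```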

The problem with this approach is that it compares all elements in all rows, even when a row already fails to match on the first element, the second, etc., which is of course very inefficient. I could remedy this by looping through the rows and comparing each row individually; then I would probably be able to stop the comparison as soon as one element is false. However, I would then be stuck with a slow for loop, which is not ideal either.

So I'm wondering: is there some other clever way to get the best of both of these approaches?

NumPy has a few useful built-in functions for checking matrix/vector equality. Combining np.equal with all(axis=1) is about twice as fast as the mean-based check in the test below:

import numpy as np
import time
x = np.random.random((1, 1000))
y = np.random.random((10000, 1000))
y[53] = x

t = time.time()
x_in_y = np.equal(x, y).all(axis=1)  # equal(x, y) returns a rows x cols boolean matrix, True where elements match; all(axis=1) reduces it to a vector of length rows that is True where the entire row of y equals x
idx = np.where(x_in_y)  # returns the indices where x_in_y is True (here 53)
print(time.time() - t)  # 0.019975900650024414

t = time.time()
res = np.mean(x == y, axis=1) == 1
print(time.time() - t)  # 0.03999614715576172
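If stopping early matters, one possible middle ground is to compare the columns in chunks and keep only the rows that still match after each chunk. The following is only a sketch under assumptions: the function name find_matching_rows and the chunk parameter are made up for illustration, and whether this beats the fully vectorized check depends on the data and the chunk size, so it should be timed on real inputs.

```python
import numpy as np

def find_matching_rows(one_d, two_d, chunk=32):
    """Return indices of rows in two_d that equal one_d.

    Hypothetical helper: columns are compared `chunk` at a time, and rows
    that already failed are dropped before the next chunk is examined.
    """
    candidates = np.arange(two_d.shape[0])
    for start in range(0, one_d.shape[0], chunk):
        if candidates.size == 0:
            break  # no row can match any more, so stop early
        stop = start + chunk
        # Compare only the surviving rows on the current block of columns.
        block_equal = (two_d[candidates, start:stop] == one_d[start:stop]).all(axis=1)
        candidates = candidates[block_equal]
    return candidates

# Test data similar to the question: integers between 0 and 20.
rng = np.random.default_rng(0)
one_d = rng.integers(0, 21, size=1000)
two_d = rng.integers(0, 21, size=(10000, 1000))
two_d[53] = one_d

print(find_matching_rows(one_d, two_d))  # [53]
```

With random integer rows, almost every candidate is eliminated by the first chunk, so the later chunks touch very few rows. If many rows share long common prefixes with the target, the benefit shrinks.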
