简体   繁体   English

计算numpy数组中有多少元素在每个其他元素的delta内

[英]count how many elements in a numpy array are within delta of every other element

consider the array x and delta variable d 考虑数组x和delta变量d

np.random.seed([3,1415])
x = np.random.randint(100, size=10)
d = 10

For each element in x , I want to count how many other elements in each are within delta d distance away. 对于x每个元素,我想计算每个元素中有多少其他元素在delta d 距离内。

So x looks like 所以x看起来像

print(x)

[11 98 74 90 15 55 13 11 13 26]

The results should be 结果应该是

[5 2 1 2 5 1 5 5 5 1]

what I've tried 我试过的
Strategy: 战略:

  • Use broadcasting to take the outer difference 使用广播来消除外部差异
  • Absolute value of outer difference 外部差异的绝对值
  • sum how many exceed threshold 总和有多少超过阈值

(np.abs(x[:, None] - x) <= d).sum(-1)

[5 2 1 2 5 1 5 5 5 1]

This works great. 这非常有效。 However, it doesn't scale. 但是,它没有扩展。 That outer difference is O(n^2) time. 外部差异是O(n ^ 2)时间。 How can I get the same solution that doesn't scale with quadratic time? 如何获得不能用二次时间缩放的相同解决方案?

Listed in this post are two more variants based on the searchsorted strategy from OP's answer post . 根据OP's answer post的搜索searchsorted strategy ,在这篇文章中列出了另外两个变种

def pir3(a,d):  # Short & less efficient
    sidx = a.argsort()
    p1 = a.searchsorted(a+d,'right',sorter=sidx)
    p2 = a.searchsorted(a-d,sorter=sidx)
    return p1 - p2

def pir4(a, d):   # Long & more efficient
    s = a.argsort()

    y = np.empty(s.size,dtype=np.int64)
    y[s] = np.arange(s.size)

    a_ = a[s]
    return (
        a_.searchsorted(a_ + d, 'right')
        - a_.searchsorted(a_ - d)
    )[y]

The more efficient approach derives the efficient idea to get s.argsort() from this post . 更有效的方法是从this post获得s.argsort()的有效想法。

Runtime test - 运行时测试 -

In [155]: # Inputs
     ...: a = np.random.randint(0,1000000,(10000))
     ...: d = 10


In [156]: %timeit pir2(a,d) #@ piRSquared's post solution
     ...: %timeit pir3(a,d)
     ...: %timeit pir4(a,d)
     ...: 
100 loops, best of 3: 2.43 ms per loop
100 loops, best of 3: 4.44 ms per loop
1000 loops, best of 3: 1.66 ms per loop

Strategy 战略

  • Since x is not necessarily sorted, we'll sort it and track the sorting permutation via argsort so we can reverse the permutation. 由于x不一定是排序的,我们将对它进行排序并通过argsort跟踪排序排列,以便我们可以反转排列。
  • We'll use np.searchsorted on x with x - d to find the starting place for when values of x start to exceed x - d . 我们将在xx - d上使用np.searchsorted来找到x值开始超过x - d时的起始位置。
  • Do it again on the other side except we'll have to use the np.searchsorted parameter side='right' and using x + d 在另一边再做一次,除了我们必须使用np.searchsorted参数side='right'并使用x + d
  • Take the difference between right and left searchsorts to calculate number of elements that are within +/- d of each element 使用右侧和左侧搜索类型之间的差异来计算每个元素的+/- d内的元素数量
  • Use argsort to reverse the sorting permutation 使用argsort来反转排序排列

define method presented in question as pir1 将问题定义为pir1

def pir1(a, d):
    return (np.abs(a[:, None] - a) <= d).sum(-1)

We'll define a new function pir2 我们将定义一个新函数pir2

def pir2(a, d):
    s = x.argsort()
    a_ = a[s]
    return (
        a_.searchsorted(a_ + d, 'right')
        - a_.searchsorted(a_ - d)
    )[s.argsort()]

demo 演示

pir1(x, d)

[5 2 1 2 5 1 5 5 5 1]    

pir1(x, d)

[5 2 1 2 5 1 5 5 5 1]    

timing 定时
pir2 is the clear winner! pir2是明显的赢家!

code

functions 功能

def pir1(a, d):
    return (np.abs(a[:, None] - a) <= d).sum(-1)

def pir2(a, d):
    s = x.argsort()
    a_ = a[s]
    return (
        a_.searchsorted(a_ + d, 'right')
        - a_.searchsorted(a_ - d)
    )[s.argsort()]

#######################
# From Divakar's post #
#######################
def pir3(a,d):  # Short & less efficient
    sidx = a.argsort()
    p1 = a.searchsorted(a+d,'right',sorter=sidx)
    p2 = a.searchsorted(a-d,sorter=sidx)
    return p1 - p2

def pir4(a, d):   # Long & more efficient
    s = a.argsort()

    y = np.empty(s.size,dtype=np.int64)
    y[s] = np.arange(s.size)

    a_ = a[s]
    return (
        a_.searchsorted(a_ + d, 'right')
        - a_.searchsorted(a_ - d)
    )[y]

test 测试

from timeit import timeit

results = pd.DataFrame(
    index=np.arange(1, 50),
    columns=['pir%s' %i for i in range(1, 5)])

for i in results.index:
    np.random.seed([3,1415])
    x = np.random.randint(1000000, size=i)
    for j in results.columns:
        setup = 'from __main__ import x, {}'.format(j)
        results.loc[i, j] = timeit('{}(x, 10)'.format(j), setup=setup, number=10000)

results.plot()

在此输入图像描述


extended out to larger arrays 扩展到更大的阵列
got rid of pir1 摆脱了pir1

from timeit import timeit

results = pd.DataFrame(
    index=np.arange(1, 11) * 1000,
    columns=['pir%s' %i for i in range(2, 5)])

for i in results.index:
    np.random.seed([3,1415])
    x = np.random.randint(1000000, size=i)
    for j in results.columns:
        setup = 'from __main__ import x, {}'.format(j)
        results.loc[i, j] = timeit('{}(x, 10)'.format(j), setup=setup, number=100)

results.insert(0, 'pir1', 0)

results.plot()

在此输入图像描述

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM