简体   繁体   English

两个矩形与 NumPy 的交集

[英]Intersection of two rectangles with NumPy

I have the following function to find the intersection of two rectangles.我有以下函数来查找两个矩形的交集。 It's a little bit slow, I don't know if it's due to the OR condition or the >, < operators.有点慢,不知道是OR条件还是>, <操作符的原因。 I wonder if there's a way to improve the performance of the is_intersect() function.我想知道是否有办法提高is_intersect()函数的性能。 Maybe with NumPy?也许用 NumPy? Or Cython?还是赛通?

import numpy as np

def is_intersect(rect1, rect2):
    xmin1, xmax1, ymin1, ymax1 = rect1
    xmin2, xmax2, ymin2, ymax2 = rect2
    if xmin1 > xmax2 or xmax1 < xmin2:
        return False
    if ymin1 > ymax2 or ymax1 < ymax2:
        return False
    return True

N_ELEMS = 100000000
rects1 = np.random.rand(N_ELEMS,4)
rects2 = np.random.rand(N_ELEMS,4)

temp_dct = dict()

for i in range(N_ELEMS):
    rect1 = rects1[i,:]
    rect2 = rects2[i,:]
    if is_intersect(rect1, rect2):
        temp_dct[i] = True

I can't profit from caching results as the points will be incremental, that is, one rectangle will move in space (never the same place).我不能从缓存结果中获利,因为点将是增量的,也就是说,一个矩形将在空间中移动(永远不会在同一个地方)。 In this example, I used NumPy's random() function, but that's not the case for my real use.在这个例子中,我使用了 NumPy 的random()函数,但我实际使用的情况并非如此。 I will call the is_intersect() function 100 000 000 times or more.我将调用is_intersect()函数 100 000 000 次或更多次。

You can improve performance by avoiding the for loop using vectorized comparison and np.any :您可以通过使用矢量化比较和np.any避免 for 循环来提高性能:

result = (1 - np.any([rects1[:,0] > rects2[:,1], 
                      rects1[:,1] < rects2[:,0], 
                      rects1[:,2] > rects2[:,3], 
                      rects1[:,3] < rects2[:,2]], 
                     axis=0)).astype(bool)

You don't get a dictionary, yet you can access result by index.你没有字典,但你可以通过索引访问result

Performance with 100M elements: 100M 元素的性能:

import numpy as np
import timeit

N_ELEMS = 100_000_000
rects1 = np.random.rand(N_ELEMS,4)
rects2 = np.random.rand(N_ELEMS,4)

start_time = timeit.default_timer()
result = (1 - np.any([rects1[:,0] > rects2[:,1], 
                      rects1[:,1] < rects2[:,0], 
                      rects1[:,2] > rects2[:,3], 
                      rects1[:,3] < rects2[:,2]], 
                     axis=0)).astype(bool)

print(timeit.default_timer() - start_time)
2.9162093999999996

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM