
Optimise this function -- numpy broadcasting issue

I have a function contains that checks, for a given 2D array u, whether the box [min, max] contains each row of u. It needs to reshape u if necessary; the number of values in u will always be a multiple of d (and can be zero).

I'm using the following snippet of code. This function runs thousands of times. Can faster code be produced? If so, any tips on how?

import numpy as np

def contains(u, min, max, dim, strict=True):
    u = np.array(u).reshape(-1, dim)
    if strict:
        return np.all((u > min) & (u < max), axis=1)
    else:
        return np.all((u >= min) & (u <= max), axis=1)

# Usage examples : 
d = 4
min = np.random.uniform(size=d)*1/2
max = np.random.uniform(size=d)*1/2+1/2
u1 = np.random.uniform(size=d)
u2 = np.random.uniform(size=(100,d))
u3 = u2[np.repeat(False,100)]

contains(u1,min,max,d) # should return a boolean array of shape (1,)
contains(u2,min,max,d) # shape (100,)
contains(u3,min,max,d) # shape (0,)

(EDITED: to fix the timing measurement issue raised by @max9111 in the comments, and to include a numexpr-modified solution.)

The bottleneck is ultimately the np.all() call (together with the temporary arrays created by the elementwise comparisons). This can be sped up with Numba, for example:

import numpy as np
import numba as nb


@nb.jit(nopython=True)
def contains_nb(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    n = arr.shape[0]
    result = np.ones(n, dtype=np.bool_)  # np.bool8 is deprecated in newer NumPy
    for i in range(n):
        for j in range(m):
            # Non-strict bounds, to match the NumPy version below; exit early
            # as soon as one coordinate falls outside the box.
            if not a_arr[j] <= arr[i, j] <= b_arr[j]:
                result[i] = False
                break
    return result

This is compared to the NumPy solution:

import numpy as np


def contains_np(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    return np.all((arr >= a_arr) & (arr <= b_arr), axis=1)

I have simplified this a bit compared to your approach: I omitted the dim parameter, since it is redundant (it can be inferred from the size of a_arr or b_arr), and the strict parameter, which does not add much to the solution but can easily be reintroduced (a sketch follows below). I also assume that the input is already a NumPy array.
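
For reference, here is a minimal sketch of what reintroducing strict could look like in the simplified NumPy version; the name contains_np_strict is just an illustrative choice, not part of the original answer:

import numpy as np


def contains_np_strict(arr, a_arr, b_arr, strict=True):
    # Infer the dimensionality from the bound arrays and reshape the input.
    m = a_arr.size
    arr = arr.reshape(-1, m)
    if strict:
        # Strict containment: rows must lie in the open box (a_arr, b_arr).
        return np.all((arr > a_arr) & (arr < b_arr), axis=1)
    # Non-strict containment: rows may lie on the box boundary.
    return np.all((arr >= a_arr) & (arr <= b_arr), axis=1)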

Also, the NumPy solution could be modified to use numexpr, which leads to a third approach. This has some calling overhead, but could speed up the computations, e.g.:

import numpy as np
import numexpr as ne


def contains_ne(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    result = ne.evaluate('(arr >= a_arr) & (arr <= b_arr)')
    return np.all(result, axis=1)
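
As a quick sanity check (not part of the original answer), the three variants can be verified to agree on inputs like the ones from the question:

import numpy as np

d = 4
a_arr = np.random.uniform(size=d) / 2
b_arr = np.random.uniform(size=d) / 2 + 1 / 2
u = np.random.uniform(size=(100, d))

# All three implementations should produce identical boolean masks.
res = contains_np(u, a_arr, b_arr)
assert np.array_equal(res, contains_nb(u, a_arr, b_arr))
assert np.array_equal(res, contains_ne(u, a_arr, b_arr))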

The following benchmarks can be obtained:

[Benchmark plot: timings of contains_np, contains_nb and contains_ne over the explored range of parameters]

This shows that the Numba solution is consistently the fastest. In contrast, using numexpr does not appear to be beneficial for the range of parameters explored.

(full benchmark available here)
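
To reproduce rough timings locally (a sketch only; the array size and repeat count below are arbitrary choices, not those of the original benchmark), timeit can be used, calling the Numba function once beforehand so that JIT compilation is excluded from the measurement, which is the timing issue mentioned in the edit above:

import timeit

import numpy as np

d = 4
a_arr = np.random.uniform(size=d) / 2
b_arr = np.random.uniform(size=d) / 2 + 1 / 2
u = np.random.uniform(size=(100_000, d))

contains_nb(u, a_arr, b_arr)  # warm-up call so compilation time is not measured

for name, func in [("contains_np", contains_np), ("contains_nb", contains_nb), ("contains_ne", contains_ne)]:
    t = timeit.timeit(lambda: func(u, a_arr, b_arr), number=100)
    print(f"{name}: {t / 100 * 1e3:.3f} ms per call")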

Try this to speed it up; read more here.

import numpy as np
from numba import jit

@jit(nopython=True)
def contains(u, min, max, dim, strict=True):
    u = np.array(u).reshape(-1, dim)
    if strict:
        return np.all((u > min) & (u < max), axis=1)
    else:
        return np.all((u >= min) & (u <= max), axis=1)
