
How can I speed up Python's dictionary with Numba?

I need to store a few cells in an array of Boolean values. At first I used NumPy, but when the arrays started to take a lot of memory, I had the idea to store only the non-zero elements in a dictionary with tuples as keys (tuples are hashable). For example: {(0, 0, 0): True, (1, 2, 3): True}. These are two cells in a "3D array" with indices 0,0,0 and 1,2,3, but the number of dimensions is not known in advance; it is defined when I run my algorithm. This helped a lot, because the non-zero cells fill only a small part of the array.
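A minimal sketch of the representation described above, assuming the cells start out in a dense NumPy boolean array: `np.argwhere` returns the index of every True cell, and each row becomes one tuple key. The helper names `dense_to_sparse` and `sparse_to_dense` are hypothetical.

```python
import numpy as np

def dense_to_sparse(arr):
    """Map each True cell of a boolean array to a tuple key with value True."""
    # np.argwhere gives one row of indices per nonzero cell, for any ndim
    return {tuple(int(v) for v in idx): True for idx in np.argwhere(arr)}

def sparse_to_dense(cells, shape):
    """Rebuild the dense boolean array from the tuple-keyed dictionary."""
    out = np.zeros(shape, dtype=np.bool_)
    for key in cells:
        out[key] = True  # a tuple key addresses one cell of an N-d array
    return out

dense = np.zeros((2, 3, 4), dtype=np.bool_)
dense[0, 0, 0] = True
dense[1, 2, 3] = True
cells = dense_to_sparse(dense)
print(cells)  # {(0, 0, 0): True, (1, 2, 3): True}
```

Only the True cells cost memory in the dictionary, which is why this pays off when they fill a small part of the array.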

To write values to and read values from this dict, I use loops:

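import numpy as np

def fill_cells(indices, area_dict):
    # indices: (m, ndim) integer array, one row per cell to mark
    for i in indices:
        area_dict[tuple(i)] = True

def get_cells(indices, area_dict):
    # returns a boolean mask: True where the row's cell is stored
    n = len(indices)
    out = np.zeros(n, dtype=np.bool_)  # np.bool was removed in NumPy 1.24
    for i in range(n):
        out[i] = tuple(indices[i]) in area_dict  # no need for .keys()
    return out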
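Because every stored value is `True`, a plain Python `set` of tuples carries the same information as the dictionary, and `key in s` is the idiomatic membership test. A minimal pure-Python sketch of that alternative; the `_set` function names are hypothetical:

```python
import numpy as np

def fill_cells_set(indices, cells):
    """Add one tuple key per row of an (m, ndim) integer index array."""
    for row in indices:
        cells.add(tuple(row.tolist()))  # tolist() yields plain Python ints

def get_cells_set(indices, cells):
    """Return a boolean mask: True where the row's cell is stored."""
    return np.array([tuple(row.tolist()) in cells for row in indices], dtype=np.bool_)

cells = set()
fill_cells_set(np.array([[0, 0, 0], [1, 2, 3]]), cells)
print(get_cells_set(np.array([[1, 2, 3], [4, 5, 6]]), cells))  # [ True False]
```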

Now I need to speed it up with Numba. Numba doesn't support the native Python dict() in nopython mode, so I used numba.typed.Dict. The problem is that Numba wants to know the length of the tuple keys when it compiles a function, so I can't even create the dictionary (the key length is not known in advance; it is defined when I call the function):

from numba import njit

@njit
def make_dict(n):
    # fails: the length of (0,)*n depends on the runtime value of n
    out = {(0,)*n: True}
    return out

Numba can't infer the type of the dictionary keys and reports this error:

Compilation is falling back to object mode WITH looplifting enabled because Function "make_dict" failed type inference due to: Invalid use of Function(<built-in function mul>) with argument(s) of type(s): (tuple(int64 x 1), int64)

If I replace n with a concrete number inside the function, it works. I worked around the problem with this trick:

from numba import njit

n = 10
# build the source with n baked in as a literal, then compile it
s = '@njit\ndef make_dict():\n\tout = {(0,)*%s:True}\n\treturn out' % n
exec(s)

But I think this is a wrong and inefficient way. I also still need to use my fill_cells and get_cells functions with the @njit decorator, and Numba returns the same error there, because these functions create tuples from NumPy arrays.

I understand the fundamental limitations of Numba (and of compilation in general), but maybe there is some way to speed up these functions, or maybe you have another solution to my cell-storing problem?

Final solution:

The main problem was that Numba needs to know the length of a tuple when it compiles the function that creates it. The trick is to generate the function anew for each tuple length: I build a string with the code that defines the function and run it with exec():

from numba import njit

n = 10
# generates: return (a[0],a[1],...,a[9],)
s = '@njit\ndef arr_to_tuple(a):\n\treturn (' + ''.join('a[%i],' % i for i in range(n)) + ')'
exec(s)

After that I can call arr_to_tuple(a) to create tuples that can be used in other @njit-decorated functions.

For example, creating an empty dictionary with tuple keys, which is needed to solve the problem:

@njit
def make_empty_dict():
    # the dummy key must have the same length n that arr_to_tuple was generated for
    tpl = arr_to_tuple(np.zeros(n, dtype=np.int64))
    out = {tpl: True}
    del out[tpl]
    return out

I insert one element and delete it again because that lets Numba infer the key and value types of the dictionary.

I also need the fill_cells and get_cells functions described in the question. This is how I rewrote them with Numba:

Writing elements: I just replaced tuple() with arr_to_tuple():

@njit
def fill_cells_nb(indices, area_dict):
    for i in range(len(indices)):
        area_dict[arr_to_tuple(indices[i])] = True

Getting elements from the dictionary required a bit of hacky code:

@njit
def get_cells_nb(indices, area_dict):
    n = len(indices)
    out = np.zeros(n, dtype=np.bool_)
    for i in range(n):
        len_before = len(area_dict)
        tpl = arr_to_tuple(indices[i])
        area_dict[tpl] = True          # adds tpl only if it was missing
        len_after = len(area_dict)
        if len_before == len_after:
            out[i] = True              # size unchanged: tpl was already stored
        else:
            del area_dict[tpl]         # size grew: undo the insertion
    return out

My version of Numba (0.46) doesn't support the `in` (__contains__) operator on typed dictionaries or the try/except construct. If your version supports them, you can write a more "regular" solution.

So to check whether an element with some index exists in the dictionary, I remember the dictionary's length, then write an element with that index into it. If the length changed, I conclude that the element didn't exist (and delete it again); otherwise the element exists. It looks like a very slow solution, but it isn't.

Speed test:

Both solutions work surprisingly fast. I tested them with %timeit against optimized native-Python code:

  1. arr_to_tuple() is 5 times faster than the regular tuple() function
  2. get_cells with Numba is 3 times faster for one element and 40 times faster for big arrays of elements, compared to the native-Python get_cells
  3. fill_cells with Numba is 4 times faster for one element and 40 times faster for big arrays of elements, compared to the native-Python fill_cells
