Numpy: How to check for tuples existence in ndarray

I found some strange behavior when working with tuples in numpy arrays. I want to get a table of booleans telling me which tuples in array a also exist in array b. Normally I would use in or in1d, but neither of them works, even though tuple(a[1]) == b[1,1] yields True.

I fill my a and b like this:

a = numpy.array([(0,0), (1,1), (2,2)], dtype=tuple)

b = numpy.zeros((3,3), dtype=tuple)
for i in range(0,3):
    for j in range(0,3):
        b[i,j] = (i,j)

Can anyone tell me a solution to my problem and please enlighten me why this does not work as expected?

(Using python2.7 and numpy1.6.2 over here btw.)

Why this doesn't work

The short version is that numpy's implementation of array.__contains__() does not treat a tuple as a single element, so for this use case it is effectively broken. The in operator in Python calls __contains__() behind the scenes.

This means that a in b is equivalent to b.__contains__(a).

I loaded your arrays in a REPL and tried the following:

>>> b[:,0]
array([(0, 0), (1, 0), (2, 0)], dtype=object)
>>> (0,0) in b[:,0] # we expect it to be true
False
>>> (0,0) in list(b[:,0]) # this shouldn't be different from the above but it is
True
>>> 
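As far as I can tell, in on an ndarray behaves like (arr == item).any(): the item is broadcast against the array and compared element by element, so a tuple is handled as a sequence of values rather than as one object. Here is a minimal sketch of that effect on a plain integer array, which avoids version-specific behavior of object arrays. The names grid and item are made up for illustration, and numpy is imported as np.

```python
import numpy as np

# Hypothetical 2-D integer array; note that neither row is (0, 0).
grid = np.array([[0, 5], [7, 0]])
item = (0, 0)

# `in` broadcasts the tuple across each row and compares element-wise.
via_in = item in grid

# Spelled out: a single matching element anywhere is enough.
via_any = bool((grid == item).any())

# What the question actually asks for: a row that matches in every position.
row_match = bool((grid == item).all(axis=1).any())
```

Here via_in and via_any agree with each other, while row_match gives a different answer, which is why in is not a reliable test for "does this pair occur in the array".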

How to fix it

I don't see how your list comprehension could work, since a[x] is a single pair and b[:,:] is a 2D matrix, so of course they're not equal. I'm assuming you meant to use in instead of ==. Do correct me if I'm wrong and you meant something different that I'm just not seeing.
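It may help to look at how the two arrays are actually laid out. The sketch below rebuilds them with dtype=object, on the assumption that this is what dtype=tuple resolves to, and inspects what each one holds.

```python
import numpy as np

# Same data as in the question; dtype=object is an assumption standing in
# for dtype=tuple.
a = np.array([(0, 0), (1, 1), (2, 2)], dtype=object)

b = np.zeros((3, 3), dtype=object)
for i in range(3):
    for j in range(3):
        b[i, j] = (i, j)  # each cell stores one Python tuple

a_shape = a.shape             # numpy unpacks the pairs into rows and columns
a_row_type = type(a[1])       # a row of a is an ndarray, not a tuple
b_cell_type = type(b[1, 1])   # a cell of b is a real tuple

# This is why the question needs tuple(...) around a[1] before comparing.
same = tuple(a[1]) == b[1, 1]
```

So a is a 3x2 array of individual numbers, while b is a 3x3 array whose cells are tuples. The two only become comparable once a row of a is converted with tuple().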

The first step is to flatten b from a 2D array into a 1D array so we can sift through it linearly, and then convert it to a list to avoid numpy's array.__contains__(), like so:

bb = list(b.reshape(b.size))

Or, better yet, make it a set: tuples are hashable, and a membership check in a set is O(1) on average instead of the list's O(n).

>>> bb = set(b.reshape(b.size))
>>> print bb
set([(0, 1), (1, 2), (0, 0), (2, 1), (1, 1), (2, 0), (2, 2), (1, 0), (0, 2)])
>>> 
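To make the list versus set choice concrete, here is a small sketch that builds both containers from the same flattened array and runs the same lookups against each. The names probe and missing are made up for illustration, and b is rebuilt with dtype=object as an assumption.

```python
import numpy as np

b = np.zeros((3, 3), dtype=object)
for i in range(3):
    for j in range(3):
        b[i, j] = (i, j)

flat = b.reshape(b.size)  # 1-D view over the nine tuples
as_list = list(flat)      # membership scans the elements one by one
as_set = set(flat)        # membership hashes the tuple and looks it up

probe = (2, 1)    # a pair that is stored in b
missing = (5, 5)  # a pair that is not

in_list = probe in as_list
in_set = probe in as_set
missing_in_set = missing in as_set
```

Both containers give the same answers. The set only pays off when b is large or when many lookups are made.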

Next, we simply use a list comprehension to derive the table of booleans:

>>> truth_table = [tuple(aa) in bb for aa in a]
>>> print truth_table
[True, True, True]
>>> 

Full code:

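def contained(a, b):
    # Flatten b and hash its tuples once, so each lookup is O(1) on average.
    bb = set(b.flatten())
    # Rows of a are not tuples, so convert each one before the lookup.
    return [tuple(aa) in bb for aa in a]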
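If the pairs are plain integers, a different route is to drop the object arrays and compare rows with broadcasting. This is not part of the approach above, just a sketch of an alternative. The function contained_vectorized and the names a_pairs and b_pairs are hypothetical, and the test data swaps in (5, 5) to show a negative result.

```python
import numpy as np

def contained(a, b):
    # The set-based approach from above.
    bb = set(b.flatten())
    return [tuple(aa) in bb for aa in a]

def contained_vectorized(a_pairs, b_pairs):
    # Hypothetical alternative: both inputs are (n, 2) integer arrays.
    # Compare every row of a_pairs with every row of b_pairs, then keep
    # the rows of a_pairs that match some row of b_pairs in all columns.
    matches = (a_pairs[:, None, :] == b_pairs[None, :, :]).all(axis=2)
    return matches.any(axis=1)

# b as in the question (dtype=object assumed in place of dtype=tuple).
b = np.zeros((3, 3), dtype=object)
for i in range(3):
    for j in range(3):
        b[i, j] = (i, j)

# Like the question's a, but with one pair that b does not contain.
a = np.array([(0, 0), (1, 1), (5, 5)], dtype=object)

from_sets = contained(a, b)

# Integer versions of the same data for the broadcasting variant.
a_pairs = a.astype(int)
b_pairs = np.array(list(b.flatten()))
from_broadcast = contained_vectorized(a_pairs, b_pairs)
```

The broadcasting version builds an intermediate array of len(a) x len(b) x 2 comparisons, so it trades memory for avoiding the Python-level loop. For large b, the set-based version is probably the safer choice.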
