
Compare tuples that contain arrays

I am learning Python and NumPy. While trying to predict a result to check my understanding, I came across this:

import numpy as np
x = np.arange(1,10).reshape(3,3)
np.where(x>5) == (np.array([1, 2, 2, 2]), np.array([2, 0, 1, 2]))

It raises the following error, and I now understand why:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
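A minimal sketch of where the error comes from: tuple == tuple compares the items pair by pair and needs a single truth value for each pair, but == between two arrays returns a boolean array, and calling bool() on a multi-element array is what raises.

```python
import numpy as np

x = np.arange(1, 10).reshape(3, 3)
idx = np.where(x > 5)
expected = (np.array([1, 2, 2, 2]), np.array([2, 0, 1, 2]))

# == on two arrays is element-wise: the result is an array of booleans, not one bool.
pair = idx[0] == expected[0]
print(pair)  # [ True  True  True  True]

# Tuple equality has to turn each item comparison into True/False,
# which means calling bool() on that array, and that is ambiguous.
raised = False
try:
    idx == expected
except ValueError as err:
    raised = True
    print("ValueError:", err)
```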

So I came up with these alternatives:

all(map(np.array_equal,
  np.where(x>5), (np.array([1, 2, 2, 2]), np.array([2, 0, 1, 2])) ))

all([np.array_equal(v,t) for (v,t) in zip(
  np.where(x>5), (np.array([1, 2, 2, 2]), np.array([2, 0, 1, 2]))) ])

all([(v==t).all() for (v,t) in zip(
  np.where(x>5), (np.array([1, 2, 2, 2]), np.array([2, 0, 1, 2]))) ])

These work, but they seem tedious and hard to read. Is there a more Pythonic or NumPy-idiomatic way to compare arrays inside a tuple?
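One caveat with the zip- and map-based versions above: both stop at the shorter input, so a tuple that is missing an array can still pass. A minimal sketch of a more readable helper (the name tuple_arrays_equal is hypothetical, not a NumPy function) that checks the lengths first and then compares pair by pair:

```python
import numpy as np

def tuple_arrays_equal(a, b):
    """Hypothetical helper: True if a and b have the same length and equal arrays pairwise."""
    # Check the lengths explicitly; zip()/map() alone would silently truncate.
    if len(a) != len(b):
        return False
    return all(np.array_equal(u, v) for u, v in zip(a, b))

x = np.arange(1, 10).reshape(3, 3)
expected = (np.array([1, 2, 2, 2]), np.array([2, 0, 1, 2]))

print(tuple_arrays_equal(np.where(x > 5), expected))      # True
print(tuple_arrays_equal(np.where(x > 5), expected[:1]))  # False: lengths 2 vs 1

# The map-only version is fooled by the shorter tuple:
print(all(map(np.array_equal, np.where(x > 5), expected[:1])))  # True
```

On Python 3.10 and later, zip(a, b, strict=True) is another option: it raises ValueError on a length mismatch instead of truncating.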

You were pretty close. The following works:

np.array_equal(np.where(x>5), (np.array([1, 2, 2, 2]), np.array([2, 0, 1, 2])))

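np.array_equal converts both arguments to arrays before comparing them, so a tuple of equal-length 1-D arrays is treated like a single 2-D array. Because of that conversion it also treats arrays and other sequences alike, so you could use plain lists instead, if you prefer: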

np.array_equal(np.where(x>5), ([1, 2, 2, 2], [2, 0, 1, 2]))
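To see what np.array_equal is actually comparing: after conversion with np.asarray, the tuple from np.where(x>5) is one array of shape (2, 4). If the shapes differ, np.array_equal simply returns False rather than raising, as this small sketch shows:

```python
import numpy as np

x = np.arange(1, 10).reshape(3, 3)
expected = ([1, 2, 2, 2], [2, 0, 1, 2])

# The tuple of two equal-length index arrays stacks into one 2-D array.
print(np.asarray(np.where(x > 5)).shape)  # (2, 4)

print(np.array_equal(np.where(x > 5), expected))  # True

# x > 4 has 5 matches, so the shapes (2, 5) and (2, 4) differ: False, no exception.
print(np.array_equal(np.where(x > 4), expected))  # False
```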

There's also a testing submodule, numpy.testing (though I haven't used it much):

In [54]: import numpy.testing

In [59]: x = np.arange(1,10).reshape(3,3)
    ...: idx = np.where(x>5); test=(np.array([1, 2, 2, 2]), np.array([2, 0, 1, 2]))

passes:

In [60]: np.testing.assert_equal(idx,test)

a failure:

In [61]: idx = np.where(x>4)
In [62]: np.testing.assert_equal(idx,test)
Traceback (most recent call last):
  File "<ipython-input-62-a0326017ecb7>", line 1, in <module>
    np.testing.assert_equal(idx,test)
  File "/usr/local/lib/python3.8/dist-packages/numpy/testing/_private/utils.py", line 338, in assert_equal
    assert_equal(actual[k], desired[k], f'item={k!r}\n{err_msg}',
  File "/usr/local/lib/python3.8/dist-packages/numpy/testing/_private/utils.py", line 344, in assert_equal
    return assert_array_equal(actual, desired, err_msg, verbose)
  File "/usr/local/lib/python3.8/dist-packages/numpy/testing/_private/utils.py", line 932, in assert_array_equal
    assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
  File "/usr/local/lib/python3.8/dist-packages/numpy/testing/_private/utils.py", line 761, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Arrays are not equal
item=0

(shapes (5,), (4,) mismatch)
 x: array([1, 1, 2, 2, 2])
 y: array([1, 2, 2, 2])

Its docs:

Raises an AssertionError if two objects are not equal.

Given two objects (scalars, lists, tuples, dictionaries or numpy arrays),
check that all elements of these objects are equal. An exception is raised
at the first conflicting values.
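Because assert_equal walks into tuples (it first checks that their lengths match, then compares item by item), it handles a tuple of arrays directly. In a real test you would just let the AssertionError propagate; the sketch below wraps it in a hypothetical helper, check_where_indices, only to show which comparisons pass:

```python
import numpy as np
from numpy.testing import assert_equal

x = np.arange(1, 10).reshape(3, 3)
expected = (np.array([1, 2, 2, 2]), np.array([2, 0, 1, 2]))

def check_where_indices(mask, want):
    """Hypothetical helper: True if assert_equal accepts the pair, False if it raises."""
    try:
        # assert_equal compares the tuple lengths, then each array element-wise.
        assert_equal(np.where(mask), want)
    except AssertionError:
        return False
    return True

print(check_where_indices(x > 5, expected))      # True
print(check_where_indices(x > 4, expected))      # False: shapes (5,) vs (4,)
print(check_where_indices(x > 5, expected[:1]))  # False: tuple lengths 2 vs 1
```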

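transpose can turn both tuples into 2-D arrays with one row per matching element (np.argwhere(x>5) gives the same result as np.transpose(np.where(x>5))):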

In [64]: np.argwhere(x>5)
Out[64]: 
array([[1, 2],
       [2, 0],
       [2, 1],
       [2, 2]])

In [66]: np.transpose(test)
Out[66]: 
array([[1, 2],
       [2, 0],
       [2, 1],
       [2, 2]])

Those can be compared element-wise if the shapes match:

In [68]: np.argwhere(x>5)==np.transpose(test)
Out[68]: 
array([[ True,  True],
       [ True,  True],
       [ True,  True],
       [ True,  True]])

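But you have to be careful when comparing arrays like this. The shapes have to match, otherwise you'll get errors (the same as if you tried to add them), unless the shapes happen to be broadcast-compatible, in which case the comparison silently broadcasts. The result is also a boolean array, so use .all() on it to get a single True or False.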
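A minimal sketch of guarding the element-wise comparison: check the shapes first, then reduce with .all(). It also shows the broadcasting pitfall, where a (2,) array is compared against every row of a (4, 2) array without any error:

```python
import numpy as np

x = np.arange(1, 10).reshape(3, 3)
test = (np.array([1, 2, 2, 2]), np.array([2, 0, 1, 2]))

a = np.argwhere(x > 5)   # shape (4, 2)
b = np.transpose(test)   # shape (4, 2)

# Confirm the shapes match before reducing the boolean array to one bool.
same = a.shape == b.shape and bool((a == b).all())
print(same)  # True

# Pitfall: a (2,) array broadcasts against (4, 2) with no error,
# so == compares every row of a with [1, 2].
row = np.array([1, 2])
print((a == row).shape)  # (4, 2)

# np.array_equal checks the shapes itself and returns False here.
print(np.array_equal(a, row))  # False
```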
