
Understanding the numpy where function

I have been trying to understand the numpy where function but am not getting anywhere. I can understand simple comparisons such as where value > otherValue, but this example from the documentation is not becoming any clearer.

I would appreciate an easy-to-understand breakdown of this. Thanks for any help provided:

>>> np.where([[True, False], [True, True]],
...          [[1, 2], [3, 4]],
...          [[9, 8], [7, 6]])
array([[1, 8],
       [3, 4]])

The where() function accepts 3 arguments: condition, x and y. As stated in the documentation, if both x and y are specified, the output array contains elements of x where condition is True, and elements of y elsewhere.
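To make the rule concrete, here is a minimal sketch that applies it one element at a time. `where_by_hand` is a hypothetical helper written only for illustration and is not part of NumPy. It assumes condition, x and y all have the same shape, so it ignores the broadcasting that the real np.where also supports.

```python
import numpy as np

def where_by_hand(condition, x, y):
    """Hypothetical illustration of three-argument np.where.

    Assumes condition, x and y share one shape (no broadcasting).
    """
    condition, x, y = np.asarray(condition), np.asarray(x), np.asarray(y)
    out = np.empty(condition.shape, dtype=np.result_type(x, y))
    # Visit every index position, e.g. (0, 0), (0, 1), (1, 0), (1, 1).
    for idx in np.ndindex(*condition.shape):
        # Take the element from x where the condition is True, else from y.
        out[idx] = x[idx] if condition[idx] else y[idx]
    return out

cond = [[True, False], [True, True]]
x = [[1, 2], [3, 4]]
y = [[9, 8], [7, 6]]

print(where_by_hand(cond, x, y))
print(np.array_equal(where_by_hand(cond, x, y), np.where(cond, x, y)))  # True
```

The comparison on the last line checks the hand-written loop against the real function for the documentation example.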

In your case, for the first row it selects 1 from x (condition True) and 8 from y (condition False). For the second row both conditions are True, so it selects both elements (3 and 4) from x.

np.where([[True, False], [True, True]],
         [[1,    2],     [3,    4]],
         [[9,    8],     [7,    6]])

array([[1, 8],
       [3, 4]])
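The same rule can also be read another way. Start from y, then overwrite the positions where the condition is True with the matching values from x. The sketch below does this with boolean-mask indexing. It is only an alternative illustration of the result and should not be taken as a description of how NumPy implements where internally.

```python
import numpy as np

cond = np.array([[True, False], [True, True]])
x = np.array([[1, 2], [3, 4]])
y = np.array([[9, 8], [7, 6]])

# Begin with a copy of y so the False positions keep y's values.
out = y.copy()
# x[cond] picks x's elements at the True positions (1, 3, 4 here),
# and assigning them to out[cond] writes them into the same positions.
out[cond] = x[cond]

print(out)
print(np.array_equal(out, np.where(cond, x, y)))  # True
```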

I think we can make things simpler and focus on np.where itself rather than on the nested lists.

np.where([True, False, True, True],
         [1,    2,     3,    4],
         [9,    8,     7,    6])
Out[4]: array([1, 8, 3, 4])

I think this simple equivalent gets the point across. Simply put, it selects the corresponding element from the first list ([1, 2, 3, 4]) where the condition is True, and from the second list ([9, 8, 7, 6]) where the condition is False.
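For the flat case, the selection can be written in plain Python by walking the three lists in lockstep. This list comprehension is a sketch for intuition only. np.where does the same selection on whole arrays and returns a NumPy array rather than a list.

```python
import numpy as np

cond = [True, False, True, True]
x = [1, 2, 3, 4]
y = [9, 8, 7, 6]

# Pair up condition, x value and y value position by position;
# keep the x value when the condition is True, otherwise the y value.
by_hand = [xv if c else yv for c, xv, yv in zip(cond, x, y)]

print(by_hand)               # [1, 8, 3, 4]
print(np.where(cond, x, y))  # [1 8 3 4]
```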

The first condition is True, so we choose 1 (from the first list, at the corresponding position). The second is False, so we choose 8 (from the second list, at the corresponding position), and so on.
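The question mentions comparisons such as value > otherValue. Such a comparison produces exactly the kind of True/False array that np.where uses as its condition. The arrays `a` and `b` below are made-up values chosen for illustration. With `a > b` as the condition, np.where picks the larger element at each position, which matches np.maximum for these inputs.

```python
import numpy as np

a = np.array([3, 8, 1, 6])
b = np.array([5, 2, 7, 4])

# An element-wise comparison yields a boolean array,
# here [False, True, False, True].
mask = a > b
print(mask)

# Where a > b take a's element, otherwise take b's element.
print(np.where(mask, a, b))  # [5 8 7 6]
print(np.maximum(a, b))      # [5 8 7 6]
```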
