
numpy.where with two-dimensional array

One can use numpy.where for selecting values from two arrays depending on a condition:

import numpy

a = numpy.random.rand(5)
b = numpy.random.rand(5)
c = numpy.where(a > 0.5, a, b)  # okay

If the array has more dimensions, however, this does not work anymore:

import numpy

a = numpy.random.rand(5, 2)
b = numpy.random.rand(5, 2)
c = numpy.where(a[:, 0] > 0.5, a, b)  # !
Traceback (most recent call last):
  File "p.py", line 10, in <module>
    c = numpy.where(a[:, 0] > 0.5, a, b)  # !
  File "<__array_function__ internals>", line 6, in where
ValueError: operands could not be broadcast together with shapes (5,) (5,2) (5,2) 

I would have expected a numpy array of shape (5, 2).

What's the issue here? How can I work around it?

Remember that numpy aligns shapes for broadcasting starting from the rightmost (trailing) dimension. A (5,) shaped array can therefore broadcast with a (2, 5) shaped array, but not with a (5, 2) shaped array, because the trailing dimensions 5 and 2 do not match. To broadcast with a (5, 2) shaped array, the mask needs a second dimension of length 1, so that its shape is (5, 1) (a dimension of length 1 can broadcast with any length).
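The rule can be checked directly. The following is a minimal sketch that uses numpy.broadcast to compute the broadcast shape without calling numpy.where; the array names mask, wide and tall are made up for illustration.

```python
import numpy

mask = numpy.zeros(5, dtype=bool)   # shape (5,)
wide = numpy.zeros((2, 5))          # shape (2, 5)
tall = numpy.zeros((5, 2))          # shape (5, 2)

# (5,) lines up with the trailing axis of (2, 5), so this broadcasts.
ok_shape = numpy.broadcast(mask, wide).shape

# (5,) against (5, 2) compares 5 with 2 on the trailing axis and fails.
try:
    numpy.broadcast(mask, tall)
    failed = False
except ValueError:
    failed = True

# A trailing axis of length 1 gives (5, 1), which broadcasts with (5, 2).
fixed_shape = numpy.broadcast(mask[:, numpy.newaxis], tall).shape

print(ok_shape, failed, fixed_shape)  # (2, 5) True (5, 2)
```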

So you need to keep the second dimension when indexing. Indexing with a plain integer, as in a[:, 0], drops the indexed dimension and returns shape (5,). You can keep it by putting the index in a one-element list, which returns shape (5, 1):

a = numpy.random.rand(5, 2)
b = numpy.random.rand(5, 2)
c = numpy.where(a[:, [0]] > 0.5, a, b) # works
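A few other spellings also keep the second axis. This sketch assumes the intent is to pick whole rows of a or b based on the first column of a; the names m_list, m_slice, m_newaxis and rows_from_a are illustrative only.

```python
import numpy

a = numpy.random.rand(5, 2)
b = numpy.random.rand(5, 2)

m_list = a[:, [0]] > 0.5                  # list index keeps the axis
m_slice = a[:, :1] > 0.5                  # length-1 slice keeps the axis
m_newaxis = a[:, 0, numpy.newaxis] > 0.5  # drop the axis, then re-add it

# All three masks have shape (5, 1) and the same contents.
print(m_list.shape, m_slice.shape, m_newaxis.shape)

c = numpy.where(m_list, a, b)

# Rows whose first column in a exceeds 0.5 come from a, the rest from b.
rows_from_a = a[:, 0] > 0.5
print(numpy.array_equal(c[rows_from_a], a[rows_from_a]))
print(numpy.array_equal(c[~rows_from_a], b[~rows_from_a]))
```

The list index makes a copy, while the slice and newaxis forms return views; for a mask this small the difference rarely matters.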

c = numpy.where(a > 0.5, a, b) works as is, because the condition then has the same shape as a and b.

However, if you want the condition to depend only on the first column of a, you need to consider the shape of the mask.

First, look at the shape of the comparison:

(a[:, 0] > 0.5).shape  # outputs (5,)

It is one-dimensional, while a and b have shape (5, 2) and are two-dimensional, so the three cannot be broadcast together.

The solution is to reshape the mask to shape (5, 1).

Your code should look like this:

a = numpy.random.rand(5, 2)
b = numpy.random.rand(5, 2)
c = numpy.where((a[:, 0] > 0.5).reshape(-1, 1), a, b)  # works
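To see what the reshaped mask selects, here is a small worked example with fixed values instead of random ones. The input values are invented for illustration; b is all zeros so the rows taken from it are easy to spot.

```python
import numpy

a = numpy.array([[0.9, 0.1],
                 [0.2, 0.8],
                 [0.7, 0.3]])
b = numpy.zeros((3, 2))

mask = (a[:, 0] > 0.5).reshape(-1, 1)   # shape (3, 1): [[True], [False], [True]]
c = numpy.where(mask, a, b)

# Rows 0 and 2 come from a (first column > 0.5); row 1 comes from b.
print(c.tolist())  # [[0.9, 0.1], [0.0, 0.0], [0.7, 0.3]]
```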

You can try:

import numpy
a = numpy.random.rand(5, 2)
b = numpy.random.rand(5, 2)
c = numpy.where(a > 0.5, a, b)
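Note that a full-shape condition selects element by element, which is not the same as selecting whole rows by the first column. A minimal sketch with invented fixed values, comparing the two:

```python
import numpy

a = numpy.array([[0.9, 0.1],
                 [0.2, 0.8],
                 [0.7, 0.3]])
b = numpy.zeros((3, 2))

# Elementwise: each entry of a is tested on its own.
elementwise = numpy.where(a > 0.5, a, b)

# Row-wise: only the first column of a decides, for the whole row.
rowwise = numpy.where(a[:, [0]] > 0.5, a, b)

print(elementwise.tolist())  # [[0.9, 0.0], [0.0, 0.8], [0.7, 0.0]]
print(rowwise.tolist())      # [[0.9, 0.1], [0.0, 0.0], [0.7, 0.3]]
```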

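Instead of: c = numpy.where(a > 0.5, a, b)

you can stack the two arrays and select along the new axis: c = numpy.take_along_axis(numpy.array([b, a]), (a > 0.5).astype(int)[None], axis=0)[0]

which works for multidimensional arrays if a and b have the same shape.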
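A runnable sketch of the stack-and-select idea, assuming the goal is to reproduce numpy.where(a > 0.5, a, b). It uses numpy.take_along_axis (available since NumPy 1.15) with an integer index built from the mask, because a boolean mask of shape (5, 2) does not line up with the leading axes of the stacked (2, 5, 2) array. The names stacked and pick are illustrative.

```python
import numpy

a = numpy.random.rand(5, 2)
b = numpy.random.rand(5, 2)

# Stack so that index 0 along the new axis means b and index 1 means a.
stacked = numpy.array([b, a])                    # shape (2, 5, 2)

# Turn the mask into 0/1 indices and give it a leading axis of length 1.
pick = (a > 0.5).astype(int)[numpy.newaxis]      # shape (1, 5, 2)

# Select one of the two stacked values at every position.
c = numpy.take_along_axis(stacked, pick, axis=0)[0]

print(numpy.array_equal(c, numpy.where(a > 0.5, a, b)))  # True
```

For this particular task numpy.where is shorter and avoids building the stacked copy, so the stacked form is mainly of interest when choosing among more than two arrays.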
