How does numpy.where() handle an equality condition when the array elements and the target are not the same data type?

I have a long list whose elements are of type int. I want to find the indices of the elements that equal a certain number, and I use np.where to do this.

The following is my original code:

import numpy as np

x = [1, 1, 2, 3]
y = np.array(x, dtype=np.float32)  # forces a float32 array
idx = list(np.where(y == 1)[0])
# idx holds the indices 0 and 1

Inspecting the code some time later, I realized that I should not have used dtype=np.float32, because it changes the data type of y to float. The correct code should be the following:

x = [1, 1, 2, 3]
y = np.array(x)  # keeps the default integer dtype
idx = list(np.where(y == 1)[0])
# idx again holds the indices 0 and 1

Surprisingly, these two code snippets produce exactly the same result.

My question

How is the equality test in the numpy.where condition handled when the data types of the array and the target do not match (int vs. float, for example)?

numpy.where is not concerned with comparing data types: its first argument is a Boolean array. When you write y == 1, that is an array comparison which returns a Boolean array, and this array is then passed as the argument to where.
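To make the two steps visible, here is a minimal sketch that stores the comparison result before handing it to np.where. The variable name mask is only illustrative and not part of the original code.

```python
import numpy as np

y = np.array([1, 1, 2, 3], dtype=np.float32)

# Step 1: the comparison runs on its own and yields a Boolean array.
mask = y == 1
print(mask.dtype)  # bool
print(mask)        # [ True  True False False]

# Step 2: np.where only receives the Boolean mask; it never sees
# the dtype of y or the literal 1.
idx = np.where(mask)[0].tolist()
print(idx)         # [0, 1]
```

Any dtype handling therefore happens inside the == comparison, before np.where is called.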

The relevant function is numpy.equal, which you invoke implicitly by writing y == 1. Its documentation describes it as returning (x1 == x2) element-wise.

What is compared are values, not types.
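As a rough illustration of comparing values rather than types, the sketch below builds the question's list once with an integer dtype and once as float32. The names y_int, y_f32, same_as_operator and eq_across_dtypes are made up for this example.

```python
import numpy as np

x = [1, 1, 2, 3]
y_int = np.array(x)                    # default integer dtype
y_f32 = np.array(x, dtype=np.float32)  # float32 dtype

# Writing y_f32 == 1 gives the same result as calling np.equal directly.
same_as_operator = np.array_equal(np.equal(y_f32, 1), y_f32 == 1)
print(same_as_operator)  # True

# Small integers are exactly representable in float32, so the integer
# and float32 arrays compare equal element by element.
eq_across_dtypes = y_int == y_f32
print(eq_across_dtypes.all())  # True
```

This is why both snippets in the question return the same indices: 1 and 1.0 are the same value.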

For example,

x, y, z = np.float64(0.25), np.float32(0.25), 0.25

These have three different types (numpy.float64, numpy.float32, float), yet x == y, y == z and x == z are all True. What matters here is that 0.25 = 1/4 is exactly representable in binary floating point.

With

x, y, z = np.float64(0.2), np.float32(0.2), 0.2

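we see that x == y is False and y == z is False but x == z is True, because Python floats are 64-bit just like np.float64. Since 1/5 is not exactly representable in binary, 32 bits and 64 bits give two different approximations to 1/5. That is why equality fails: not because of the types, but because np.float64(0.2) and np.float32(0.2) are actually different values (their difference is about 3e-9). Note that the result of y == z relies on the Python float being treated as a 64-bit value; the promotion rules for Python scalars changed in NumPy 2.0 (NEP 50), so that particular comparison may evaluate differently on newer versions.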
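The size of that gap can be checked directly. The following is a small sketch using only NumPy scalars, so it does not depend on how Python floats are promoted; the names a64, a32, b64, b32 and gap are chosen for this example.

```python
import numpy as np

a64 = np.float64(0.2)
a32 = np.float32(0.2)

# Converting float32 to a 64-bit float is exact, so the subtraction
# exposes the value that float32 actually stores for 0.2.
gap = float(a32) - float(a64)
print(gap)         # roughly 3e-9
print(a64 == a32)  # False: two different approximations of 1/5

# 0.25 = 1/4 is exact in both precisions, so the values coincide.
b64, b32 = np.float64(0.25), np.float32(0.25)
print(b64 == b32)  # True
```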
