Select entry that satisfies condition

I'm using NumPy and have the following structure:

self.P = np.zeros((self.nS, self.nA, self.nS))

For example, one instance of this structure could be:

import numpy as np

Pl = np.zeros((7,2,7))
Pl[0,0,1]=1
Pl[1,0,2]=1
Pl[2,0,3]=1
Pl[3,0,4]=1
Pl[4,0,5]=1
Pl[5,0,6]=0.9
Pl[5,0,5]=0.1
Pl[6,0,6]=1
Pl[0,1,0]=1
Pl[1,1,1]=0
Pl[1,1,0]=1
Pl[2,1,1]=1
Pl[3,1,2]=1
Pl[4,1,3]=1
Pl[5,1,4]=1
Pl[6,1,5]=1

Given a number e, I want to select one entry whose value is < e.

Another condition: I know the first index (along the first nS axis, called x in my attempt below), but the other two indices can vary.

I tried implementing it this way:

self.P[self.P[x,:,:] < e]

But it gives me this error:

IndexError: boolean index did not match indexed array along dimension 0; dimension is 7 but corresponding boolean dimension is 2

Any help is really appreciated.

The issue with your current attempt is that you're indexing the entire array with a boolean mask that only has the shape of the slice you selected, which is what raises your IndexError.

Check out the shapes for yourself:

>>> Pl.shape
(7, 2, 7)
>>> x = 2
>>> (Pl[x] < 5).shape
(2, 7)
>>> Pl[Pl[x] < 5]
IndexError: boolean index did not match indexed array along dimension 0; dimension is 7 but corresponding boolean dimension is 2
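To reproduce this end to end, here is a self-contained sketch that rebuilds the Pl array from the question (x = 2 and e = 0.5 are just example values chosen here). It shows that the mask has the shape of the slice, that indexing the full array with it raises the IndexError, and that indexing the slice instead works:

```python
import numpy as np

# Rebuild the example array from the question: key = (s, a, s'), value = entry.
# Pl[1,1,1] = 0 is omitted because the array already starts as zeros.
entries = {
    (0, 0, 1): 1.0, (1, 0, 2): 1.0, (2, 0, 3): 1.0, (3, 0, 4): 1.0,
    (4, 0, 5): 1.0, (5, 0, 6): 0.9, (5, 0, 5): 0.1, (6, 0, 6): 1.0,
    (0, 1, 0): 1.0, (1, 1, 0): 1.0, (2, 1, 1): 1.0, (3, 1, 2): 1.0,
    (4, 1, 3): 1.0, (5, 1, 4): 1.0, (6, 1, 5): 1.0,
}
Pl = np.zeros((7, 2, 7))
for idx, value in entries.items():
    Pl[idx] = value

x, e = 2, 0.5
mask = Pl[x] < e                  # mask has the shape of the slice: (2, 7)
print(mask.shape)

try:
    Pl[mask]                      # a 2-D mask is matched against Pl's axes 0 and 1, i.e. (7, 2)
except IndexError as err:
    print("IndexError:", err)

# Indexing the slice (or fixing axis 0 first) lines the mask up with axes 1 and 2.
print(Pl[x][mask].shape)
print(np.array_equal(Pl[x][mask], Pl[x, mask]))
```

Both Pl[x][mask] and Pl[x, mask] select the same elements; the second form avoids creating the intermediate view explicitly.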

Instead, apply the boolean mask only to the slice you selected: fix the first index and let the mask cover the remaining two axes:

>>> Pl[x]
array([[0., 0., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0.]])
>>> e = 0.5
>>> Pl[x, Pl[x] < e]
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
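The expression above returns only the values that pass the test, which here are all 0. If what you actually need is which entry qualifies, one option is np.argwhere on the same mask, which gives the (a, s') index pairs; you can then pick one of them. A minimal sketch follows; the helper name select_entry_below and the choice of picking uniformly at random are assumptions for illustration, not something the question specifies:

```python
import numpy as np

def select_entry_below(P, x, e, rng=None):
    """Hypothetical helper: return one (a, s2) pair with P[x, a, s2] < e, or None.

    Picks uniformly at random among all qualifying pairs (an assumption;
    use candidates[0] instead for the first match in row-major order).
    """
    candidates = np.argwhere(P[x] < e)    # shape (k, 2); each row is (a, s2)
    if len(candidates) == 0:
        return None                       # no entry in P[x] is below e
    rng = np.random.default_rng() if rng is None else rng
    a, s2 = candidates[rng.integers(len(candidates))]
    return int(a), int(s2)

# Usage with the array from the question (non-zero entries only).
Pl = np.zeros((7, 2, 7))
for idx in [(0, 0, 1), (1, 0, 2), (2, 0, 3), (3, 0, 4), (4, 0, 5), (6, 0, 6),
            (0, 1, 0), (1, 1, 0), (2, 1, 1), (3, 1, 2), (4, 1, 3), (5, 1, 4), (6, 1, 5)]:
    Pl[idx] = 1.0
Pl[5, 0, 6], Pl[5, 0, 5] = 0.9, 0.1

a, s2 = select_entry_below(Pl, x=2, e=0.5, rng=np.random.default_rng(0))
print(a, s2, Pl[2, a, s2])
```

Inside your class the call would simply be select_entry_below(self.P, x, e).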
