
NumPy: get the column indices where all elements are greater than a threshold

I want to find the column indices of a numpy array where all the elements of the column are greater than a threshold value.

For example,

 X = array([[ 0.16,  0.40,  0.61,  0.48,  0.20],
            [ 0.42,  0.79,  0.64,  0.54,  0.52],
            [ 0.64,  0.64,  0.24,  0.63,  0.43],
            [ 0.33,  0.54,  0.61,  0.43,  0.29],
            [ 0.25,  0.56,  0.42,  0.69,  0.62]])

In the above case, if the threshold is 0.4, my result should be 1,3.

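You can use np.where and compare each column's min against the threshold (>= is used because column 1 contains 0.40 itself and is expected in the result):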

large = np.where(X.min(0) >= 0.4)[0]
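
For reference, here is a minimal runnable sketch of the min-based approach on the example array from the question. The names via_min and via_all are only illustrative; the second form, a boolean mask reduced with all(axis=0), is an equivalent alternative and not part of the original answer.

```python
import numpy as np

X = np.array([[0.16, 0.40, 0.61, 0.48, 0.20],
              [0.42, 0.79, 0.64, 0.54, 0.52],
              [0.64, 0.64, 0.24, 0.63, 0.43],
              [0.33, 0.54, 0.61, 0.43, 0.29],
              [0.25, 0.56, 0.42, 0.69, 0.62]])
threshold = 0.4

# Column-wise minimum: a column qualifies when its smallest entry does.
via_min = np.where(X.min(axis=0) >= threshold)[0]

# Same test written as a boolean mask reduced down each column.
via_all = np.flatnonzero((X >= threshold).all(axis=0))

print(via_min)  # [1 3]
print(via_all)  # [1 3]
```

Both forms avoid a Python-level loop. The mask form may read more directly when the condition is something other than a simple lower bound.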

A generic solution using a list comprehension:

import numpy as np

threshold = 0.4
rows_nb, col_nb = np.shape(X)
# Keep each column index whose entries are all at least the threshold.
cols_above_threshold = [col for col in range(col_nb)
                        if all(X[row][col] >= threshold for row in range(rows_nb))]
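
Another option is to loop over the columns explicitly and test each one:

import numpy as np

x = np.array([[ 0.16,  0.40,  0.61,  0.48,  0.20],
              [ 0.42,  0.79,  0.64,  0.54,  0.52],
              [ 0.64,  0.64,  0.24,  0.63,  0.43],
              [ 0.33,  0.54,  0.61,  0.43,  0.29],
              [ 0.25,  0.56,  0.42,  0.69,  0.62]])

threshold = 0.3
n_cols = np.shape(x)[1]
for col in range(n_cols):
    # x[:, col] selects one column; x[col] would select a row instead.
    y = x[:, col] > threshold
    print(col, y.all())

Give it a try.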
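
One detail worth checking is whether the comparison should be strict. The sketch below wraps the test in a small helper; columns_above and its strict flag are hypothetical names introduced here for illustration, not something from the answers above. With the threshold at 0.4, column 1 only qualifies under the inclusive comparison, because it contains 0.40 exactly.

```python
import numpy as np

def columns_above(a, threshold, strict=False):
    """Return indices of columns whose entries all pass the threshold."""
    a = np.asarray(a)
    # strict=True uses >, otherwise >= (entries equal to the threshold pass).
    mask = a > threshold if strict else a >= threshold
    return np.flatnonzero(mask.all(axis=0))

X = np.array([[0.16, 0.40, 0.61, 0.48, 0.20],
              [0.42, 0.79, 0.64, 0.54, 0.52],
              [0.64, 0.64, 0.24, 0.63, 0.43],
              [0.33, 0.54, 0.61, 0.43, 0.29],
              [0.25, 0.56, 0.42, 0.69, 0.62]])

inclusive = columns_above(X, 0.4)
exclusive = columns_above(X, 0.4, strict=True)

print(inclusive)  # [1 3]
print(exclusive)  # [3]
```

The expected result of 1,3 in the question therefore corresponds to the inclusive form.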
