[英]Remove elements in an ndarray based on condition on one dimension
In a Numpy ndarray, how do I remove elements in a dimension based on condition in a different dimension?在 Numpy ndarray 中,如何根据不同维度中的条件删除维度中的元素?
I have:我有:
[[[1 3]
[1 4]]
[[2 6]
[2 8]]
[[3 5]
[3 5]]]
I want to remove based on condition x[:,:,1] < 7
我想根据条件x[:,:,1] < 7
删除
Desired output ( [:,1,:]
removed):所需的输出( [:,1,:]
已删除):
[[[1 3]
[1 4]]
[[3 5]
[3 5]]]
EDIT: fixed typo编辑:固定错字
This may work:这可能有效:
x[np.where(np.all(x[..., 1] < 7, axis=1)), ...]
yields产量
array([[[[1, 3],
[1, 4]],
[[3, 5],
[3, 5]]]])
You do get an extra dimension, but that's easy to remove:你确实得到了一个额外的维度,但这很容易删除:
np.squeeze(x[np.where(np.all(x[..., 1] < 7, axis=1)), ...])
Briefly how it works:简要说明它是如何工作的:
First the condition: x[..., 1] < 7
.首先是条件: x[..., 1] < 7
。
Then test if the condition is valid for all elements along the specific axis: np.all(x[..., 1] < 7, axis=1)
.然后测试条件是否对沿特定轴的所有元素都有效: np.all(x[..., 1] < 7, axis=1)
。
Then, use where
to grab the indices instead of an array of booleans: np.where(np.all(x[..., 1] < 7, axis=1))
.然后,使用where
来获取索引而不是布尔数组: np.where(np.all(x[..., 1] < 7, axis=1))
。
And insert those indices into the relevant dimension: x[np.where(np.all(x[..., 1] < 7, axis=1)), ...]
.并将这些索引插入相关维度: x[np.where(np.all(x[..., 1] < 7, axis=1)), ...]
。
As your desired output, you filter x
on axis=0.作为您想要的输出,您在轴 = 0 上过滤x
。 Therefore, you may try this way因此,您可以尝试这种方式
m = (x[:,:,1] < 7).all(1)
x_out = x[m,:,:]
Or simply或者干脆
x_out = x[m]
Out[70]:
array([[[1, 3],
[1, 4]],
[[3, 5],
[3, 5]]])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.