Numpy trim_zeros in 2D or 3D
How can I remove leading and trailing zeros from a NumPy array? trim_zeros works only for 1-D arrays.
Here's some code that will handle 2-D arrays.
import numpy as np

# Arbitrary array
arr = np.array([
    [0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 0, 1, 0],
    [1, 1, 0, 1, 0],
    [1, 0, 0, 1, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0]
])

nz = np.nonzero(arr)  # Indices of all nonzero elements
arr_trimmed = arr[nz[0].min():nz[0].max()+1,
                  nz[1].min():nz[1].max()+1]

assert np.array_equal(arr_trimmed, [
    [0, 0, 0, 1],
    [0, 1, 1, 1],
    [0, 1, 0, 1],
    [1, 1, 0, 1],
    [1, 0, 0, 1],
])
This can be generalized to N dimensions as follows:
def trim_zeros(arr):
    """Returns a trimmed view of an n-D array excluding any outer
    regions which contain only zeros.
    """
    slices = tuple(slice(idx.min(), idx.max() + 1) for idx in np.nonzero(arr))
    return arr[slices]

test = np.zeros((5,5,5,5))
test[1:3,1:3,1:3,1:3] = 1
trimmed_array = trim_zeros(test)
assert trimmed_array.shape == (2, 2, 2, 2)
assert trimmed_array.sum() == 2**4
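One caveat with the function above: if the array has no nonzero elements, np.nonzero returns empty index arrays and calling .min() on them raises a ValueError. Below is a minimal sketch of a guarded variant. The name trim_zeros_safe and the choice to return an empty view for all-zero input are assumptions made for illustration, not part of the original answer.

```python
import numpy as np

def trim_zeros_safe(arr):
    """Hypothetical variant of trim_zeros that tolerates all-zero input."""
    nz = np.nonzero(arr)
    if nz[0].size == 0:
        # Nothing to keep: return an empty view with the same number of dimensions
        return arr[(slice(0, 0),) * arr.ndim]
    slices = tuple(slice(idx.min(), idx.max() + 1) for idx in nz)
    return arr[slices]

empty = trim_zeros_safe(np.zeros((3, 4)))                  # all zeros
kept = trim_zeros_safe(np.array([[0, 0, 0], [0, 7, 0]]))   # single nonzero entry
print(empty.shape, kept.shape)
```

Whether an empty array, the untouched input, or an exception is the right answer for all-zero data depends on the caller, so treat the empty view as one possible convention.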
The following function works for any number of dimensions:
def trim_zeros(arr, margin=0):
    '''
    Trim the leading and trailing zeros from an N-D array.
    :param arr: numpy array
    :param margin: how many zeros to leave as a margin
    :returns: trimmed array
    :returns: tuple of slice objects used for the trimming
    '''
    s = []
    for dim in range(arr.ndim):
        start = 0
        end = -1
        slice_ = [slice(None)]*arr.ndim
        # Walk forward along this axis until a hyperplane contains a nonzero value
        go = True
        while go:
            slice_[dim] = start
            go = not np.any(arr[tuple(slice_)])
            start += 1
        start = max(start-1-margin, 0)
        # Walk backward from the end of this axis in the same way
        go = True
        while go:
            slice_[dim] = end
            go = not np.any(arr[tuple(slice_)])
            end -= 1
        end = arr.shape[dim] + min(-1, end+1+margin) + 1
        s.append(slice(start,end))
    return arr[tuple(s)], tuple(s)

Which can be tested with:
test = np.zeros((3,4,5,6))
test[1,2,2,5] = 1
trim_zeros(test, margin=1)
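
To make the effect of margin easier to see, here is a compact sketch that should produce the same slices for this test case. It swaps the loops for np.nonzero and clamps both bounds to the array, so it is an alternative formulation rather than the function above. The name trim_with_margin is hypothetical.

```python
import numpy as np

def trim_with_margin(arr, margin=0):
    """Hypothetical np.nonzero-based sketch; returns (trimmed view, slices)."""
    slices = tuple(
        slice(max(int(idx.min()) - margin, 0),           # clamp at the start of the axis
              min(int(idx.max()) + margin + 1, size))    # clamp at the end of the axis
        for idx, size in zip(np.nonzero(arr), arr.shape)
    )
    return arr[slices], slices

test = np.zeros((3, 4, 5, 6))
test[1, 2, 2, 5] = 1
trimmed, slices = trim_with_margin(test, margin=1)
print(trimmed.shape)  # (3, 3, 3, 2)
print(slices)
```

The last axis keeps only two elements because the nonzero value sits at the final index, so there is no room for a trailing margin there.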
I would like to extend the previous answers to n dimensions with an option to ignore selected axes:
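def array_trim(arr, ignore=(), margin=0):
    # One index array per axis, covering every nonzero element
    nonzero = np.nonzero(arr)
    idx = ()
    for axis, indices in enumerate(nonzero):
        if axis in ignore:
            # Leave ignored axes untouched
            idx += (np.s_[:],)
        else:
            # Clamp at 0 so the margin cannot wrap around to a negative index
            start = max(indices.min() - margin, 0)
            idx += (np.s_[start:indices.max() + margin + 1],)
    return arr[idx]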
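
A short usage sketch for the ignore argument. The function is repeated here so the snippet runs on its own, with the lower bound clamped at 0 so that a margin cannot wrap to a negative index. The sample array is made up for illustration.

```python
import numpy as np

def array_trim(arr, ignore=(), margin=0):
    nonzero = np.nonzero(arr)  # one index array per axis
    idx = ()
    for axis, indices in enumerate(nonzero):
        if axis in ignore:
            idx += (np.s_[:],)  # keep the full extent of ignored axes
        else:
            start = max(int(indices.min()) - margin, 0)
            idx += (np.s_[start:int(indices.max()) + margin + 1],)
    return arr[idx]

arr = np.zeros((4, 5))
arr[1:3, 2:4] = 1  # nonzero block in rows 1-2, columns 2-3

both = array_trim(arr)                 # trim every axis
rows_kept = array_trim(arr, ignore=[0])  # axis 0 is left as-is
padded = array_trim(arr, margin=1)     # keep one zero around the block where possible
print(both.shape, rows_kept.shape, padded.shape)  # (2, 2) (4, 2) (4, 4)
```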