Non-overlapping sliding window for 2D numpy array?
I am trying to create a non-overlapping sliding window for a 2D array in Python. My code works for a small array, but when I scale it up to a 4552 x 4552 array with a 455 x 455 window, I get the following error:
ValueError: array is too big; 'arr.size * arr.dtype.itemsize' is larger than the maximum possible size
Any suggestions?
import numpy as np

def rolling_window(a, shape, writeable=False):  # rolling window for 2D array
    # One window per valid top-left corner: (rows, cols, window_h, window_w)
    s = (a.shape[0] - shape[0] + 1,) + (a.shape[1] - shape[1] + 1,) + shape
    strides = a.strides + a.strides
    allviews = np.lib.stride_tricks.as_strided(a, shape=s, strides=strides)
    # Hard-coded steps: every 2nd window vertically, every 3rd horizontally
    non_overlapping_views = allviews[0::2, 0::3]
    return non_overlapping_views

a = np.array([[0, 1, 2, 3, 4, 5],
              [6, 7, 8, 9, 10, 11],
              [12, 13, 14, 15, 16, 17],
              [18, 19, 20, 21, 22, 23],
              [24, 25, 26, 27, 28, 29],
              [30, 31, 32, 33, 34, 35]], dtype=int)
shape = (3, 3)
result = rolling_window(a, shape)
print(result)
[[[[ 0  1  2]
   [ 6  7  8]
   [12 13 14]]

  [[ 3  4  5]
   [ 9 10 11]
   [15 16 17]]]


 [[[12 13 14]
   [18 19 20]
   [24 25 26]]

  [[15 16 17]
   [21 22 23]
   [27 28 29]]]]
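It seems to me that your problem is directly related to convolution/max-pooling operations.
I recently wrote naive implementations of these operations using only numpy, so I adapted one of them to (hopefully) answer your question.
Here is the code: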
import numpy as np

def rolling_window(input_array, size_kernel, stride, print_dims=True):
    """Function to get rolling windows.

    Arguments:
        input_array {numpy.array} -- Input, treated as a (height, width) image.
            If the input has (height, width, channel) dimensions, only the
            first channel is used, so it is reduced to (height, width).
        size_kernel {int} -- Size of the kernel to be applied. Usually 3, 5, 7.
            A kernel of (size_kernel, size_kernel) is applied to the image.
        stride {int or tuple} -- Vertical and horizontal displacement (h, w)

    Keyword Arguments:
        print_dims {bool} -- Print the input, kernel and stride dimensions
            and the number of windows (default: {True})

    Returns:
        [list] -- A list with the resulting numpy.arrays
    """
    # Check right input dimension
    assert input_array.ndim in (2, 3), \
        "input_array must have dimension 2 or 3. Yours has dimension {}".format(input_array.ndim)
    if input_array.ndim == 3:
        input_array = input_array[:, :, 0]

    # Stride: vertical (sh) and horizontal (sw) displacement
    if isinstance(stride, int):
        sh, sw = stride, stride
    elif isinstance(stride, tuple):
        sh, sw = stride
    else:
        raise TypeError("stride must be an int or a tuple of two ints")

    # Input dimension (height, width)
    n_ah, n_aw = input_array.shape

    # Filter dimension (or window)
    n_k = size_kernel

    # Number of windows that fit along each axis
    dim_out_h = (n_ah - n_k) // sh + 1
    dim_out_w = (n_aw - n_k) // sw + 1

    # List to save output arrays
    list_tensor = []

    # Initialize row position
    start_row = 0
    for i in range(dim_out_h):
        start_col = 0
        for j in range(dim_out_w):
            # Get one window (a slice, so it is a view and no data is copied)
            sub_array = input_array[start_row:(start_row + n_k), start_col:(start_col + n_k)]
            # Append sub_array
            list_tensor.append(sub_array)
            start_col += sw
        start_row += sh

    if print_dims:
        print("- Input tensor dimensions -- ", input_array.shape)
        print("- Kernel dimensions -- ", (n_k, n_k))
        print("- Stride (h,w) -- ", (sh, sw))
        print("- Total windows -- ", len(list_tensor))

    return list_tensor
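To get non-overlapping rolling windows, you just need to set stride and size_kernel to the same number. The function also prints some information, including the number of windows found.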
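When stride and size_kernel are equal, the windows tile the array, so the same split can also be written as a reshape instead of a loop. The following is only a sketch of that idea, not part of the answer's code: `block_view` is a hypothetical helper name, and it assumes a 2D input and a square window of side `k`. Rows and columns that do not fill a whole block are dropped, which matches what the loop-based function does.

```python
import numpy as np

def block_view(a, k):
    """Split a 2D array into non-overlapping k x k blocks (hypothetical helper)."""
    n_h, n_w = a.shape[0] // k, a.shape[1] // k
    # Drop trailing rows/columns that do not fill a whole block
    trimmed = a[:n_h * k, :n_w * k]
    # (n_h, k, n_w, k) -> (n_h, n_w, k, k): blocks[i, j] is the block
    # in block-row i and block-column j
    return trimmed.reshape(n_h, k, n_w, k).swapaxes(1, 2)

a = np.arange(36).reshape(6, 6)
blocks = block_view(a, 3)
print(blocks.shape)   # (2, 2, 3, 3)
print(blocks[1, 0])   # rows 3-5, columns 0-2 of a
```

Note that NumPy may have to copy the data in the reshape step when trimming leaves a non-contiguous slice; when the array divides evenly, the result is a view.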
Some examples:
1) Yours:
a = np.array([[0, 1, 2, 3, 4, 5],
              [6, 7, 8, 9, 10, 11],
              [12, 13, 14, 15, 16, 17],
              [18, 19, 20, 21, 22, 23],
              [24, 25, 26, 27, 28, 29],
              [30, 31, 32, 33, 34, 35]], dtype=int)
size_kernel = 3  # a single int, giving a (3, 3) window
list_tensor = rolling_window(a, size_kernel, stride=3, print_dims=True)
##Output
#- Input tensor dimensions -- (6, 6)
#- Kernel dimensions -- (3, 3)
#- Stride (h,w) -- (3, 3)
#- Total windows -- 4
for array in list_tensor:
    print('\n')
    print(array)
## Output
[[ 0  1  2]
 [ 6  7  8]
 [12 13 14]]

[[ 3  4  5]
 [ 9 10 11]
 [15 16 17]]

[[18 19 20]
 [24 25 26]
 [30 31 32]]

[[21 22 23]
 [27 28 29]
 [33 34 35]]
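(Also, comparing this with your results, I noticed that your code gives you overlapping data: compare your second and fourth outputs, which both contain the row [15 16 17].)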
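The overlap in the question's code comes from the hard-coded steps in `allviews[0::2, 0::3]`: the vertical step is 2 while the window is 3 rows tall. A possible fix, shown here only as a sketch, is to step by the window shape itself. This version uses `numpy.lib.stride_tricks.sliding_window_view`, which requires NumPy 1.20 or later; `non_overlapping_windows` is a hypothetical name, and whether this avoids the ValueError on the asker's setup has not been verified.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

def non_overlapping_windows(a, shape):
    """Return non-overlapping windows of a 2D array as a read-only view."""
    # All overlapping windows: (rows - h + 1, cols - w + 1, h, w)
    all_views = sliding_window_view(a, shape)
    # Step by the window height and width so that windows do not overlap
    return all_views[::shape[0], ::shape[1]]

a = np.arange(36).reshape(6, 6)
result = non_overlapping_windows(a, (3, 3))
print(result.shape)   # (2, 2, 3, 3)
print(result[1, 1])   # bottom-right 3x3 block of a
```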
2) A (5000, 5000) array with a (500, 500) window:
%%time
size = 5000
a = np.ones((size,size))
list_tensor = rolling_window(input_array = a, size_kernel = 500, stride=500, print_dims = True)
## Output
#- Input tensor dimensions -- (5000, 5000)
#- Kernel dimensions -- (500, 500)
#- Stride (h,w) -- (500, 500)
#- Total windows -- 100
#CPU times: user 68.7 ms, sys: 115 ms, total: 184 ms
#Wall time: 269 ms
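The number of windows the function reports follows from the formula it uses per axis, floor((n - k) / s) + 1. As a quick check, the sketch below applies it to the sizes in the question (4552 with a 455 window) as well as to the two examples above; `n_windows` is a hypothetical helper, not part of the answer's function.

```python
def n_windows(n, k, s):
    """Number of windows of size k, taken every s elements, that fit in length n."""
    return (n - k) // s + 1

# Sizes from the question: 4552 x 4552 array, 455 x 455 window, stride 455
per_axis = n_windows(4552, 455, 455)
leftover = 4552 - per_axis * 455  # rows/columns not covered by any window

print(per_axis, per_axis ** 2, leftover)  # 10 100 2
print(n_windows(5000, 500, 500))          # 10
print(n_windows(6, 3, 3))                 # 2
```

Since 4552 is not a multiple of 455, the last 2 rows and 2 columns do not belong to any window.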
I hope this helps!