![](/img/trans.png)
[英]How to initialise a fixed-size ListArray in pyarrow from a numpy array efficiently?
[英]In numpy, how to efficiently list all fixed-size submatrices?
我有一個任意的NxM矩陣,例如:
1 2 3 4 5 6
7 8 9 0 1 2
3 4 5 6 7 8
9 0 1 2 3 4
我想得到這個矩陣中所有3x3子矩陣的列表:
1 2 3 2 3 4 0 1 2
7 8 9 ; 8 9 0 ; ... ; 6 7 8
3 4 5 4 5 6 2 3 4
我可以用兩個嵌套循環來做到這一點:
rows, cols = input_matrix.shape
patches = []
for row in np.arange(0, rows - 3):
for col in np.arange(0, cols - 3):
patches.append(input_matrix[row:row+3, col:col+3])
但對於大輸入矩陣,這很慢。 有沒有辦法用numpy更快地做到這一點?
我看過np.split
,但這給了我非重疊的子矩陣,而我想要所有可能的子矩陣,無論重疊。
你想要一個窗口視圖:
from numpy.lib.stride_tricks import as_strided
arr = np.arange(1, 25).reshape(4, 6) % 10
sub_shape = (3, 3)
view_shape = tuple(np.subtract(arr.shape, sub_shape) + 1) + sub_shape
arr_view = as_strided(arr, view_shape, arr.strides * 2
arr_view = arr_view.reshape((-1,) + sub_shape)
>>> arr_view
array([[[[1, 2, 3],
[7, 8, 9],
[3, 4, 5]],
[[2, 3, 4],
[8, 9, 0],
[4, 5, 6]],
...
[[9, 0, 1],
[5, 6, 7],
[1, 2, 3]],
[[0, 1, 2],
[6, 7, 8],
[2, 3, 4]]]])
這樣做的好處在於你不是在復制任何數據,只是以不同的方式訪問原始數組的數據。 對於大型陣列,這可以節省大量內存。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.