[英]Iterate over last axis of a numpy array
假設我們有一個 (20, 5) 數組。 我們可以非常python地迭代每一行:
import numpy as np
xs = np.array(range(100)).reshape(20, 5)
for x in xs:
print(x)
如果我們想迭代另一個軸(在這個例子中,迭代列,但我正在為 ndarray 中的每個可能的軸尋找一個解決方案),它不那么直接,我們可以使用 Iterating over任意維度的方法numpy.array :
for i in range(xs.shape[-1]):
x = xs[..., i]
print(x)
是否有更直接的方法來迭代另一個軸,例如(偽代碼):
for x in xs.iterator(axis=-1):
print(x)
?
我認為stride 技巧模塊中的as_strided應該在這里完成工作。
它在數組中創建一個視圖而不是副本(如文檔所述)。
以下是as_stided
功能的簡單演示:
from numpy.lib.stride_tricks import as_strided
import numpy as np
xs = np.array(range(3 *3 * 4)).reshape(3,3, 4)
for x in xs:
print(x)
輸出:
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]
[[24 25 26 27]
[28 29 30 31]
[32 33 34 35]]
迭代數組特定軸的函數:
def iterate_over_axis(arr, axis=0):
strides = arr.strides
strides_ = [strides[axis], *strides[0:axis], *strides[(axis+1):]]
shape = arr.shape
shape_ = [shape[axis], *shape[0:axis], *shape[(axis+1):]]
return as_strided(arr, strides=strides_, shape=shape_)
for x in iterate_over_axis(xs, axis=1):
print(x)
輸出:
[[ 0 1 2 3]
[12 13 14 15]
[24 25 26 27]]
[[ 4 5 6 7]
[16 17 18 19]
[28 29 30 31]]
[[ 8 9 10 11]
[20 21 22 23]
[32 33 34 35]]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.