[英]Generating indices of a 2D NumPy array
我想生成一個 2D numpy
數組,其中的元素是從它們的位置計算出來的。 類似於以下代碼:
import numpy as np
def calculate_element(i, j, other_parameters):
# do something
return value_at_i_j
def main():
arr = np.zeros((M, N)) # (M, N) is the shape of the array
for i in range(M):
for j in range(N):
arr[i][j] = calculate_element(i, j, ...)
這段代碼運行得非常慢,因為 Python 中的循環效率不高。 在這種情況下有沒有辦法更快地做到這一點?
順便說一句,現在我通過計算兩個二維“索引矩陣”來使用解決方法。 是這樣的:
def main():
index_matrix_i = np.array([range(M)] * N).T
index_matrix_j = np.array([range(N)] * M)
'''
index_matrix_i is like
[[0,0,0,...],
[1,1,1,...],
[2,2,2,...],
...
]
index_matrix_j is like
[[0,1,2,...],
[0,1,2,...],
[0,1,2,...],
...
]
'''
arr = calculate_element(index_matrix_i, index_matrix_j, ...)
Edit1:應用“索引矩陣”技巧后代碼變得更快,所以我想問的主要問題是是否有辦法不使用這個技巧,因為它需要更多 memory。簡而言之,我想有一個在時間和空間上都高效的解決方案。
Edit2:我測試的一些例子
# a simple 2D Gaussian
def calculate_element(i, j, i_mid, j_mid, i_sig, j_sig):
gaus_i = np.exp(-((i - i_mid)**2) / (2 * i_sig**2))
gaus_j = np.exp(-((j - j_mid)**2) / (2 * j_sig**2))
return gaus_i * gaus_j
# size of M, N
M, N = 1200, 4000
# use for loops to go through every element
# this code takes ~10 seconds
def main_1():
arr = np.zeros((M, N)) # (M, N) is the shape of the array
for i in range(M):
for j in range(N):
arr[i][j] = calculate_element(i, j, 600, 2000, 300, 500)
# print(arr)
plt.figure(figsize=(8, 5))
plt.imshow(arr, aspect='auto', origin='lower')
plt.show()
# use index matrices
# this code takes <1 second
def main_2():
index_matrix_i = np.array([range(M)] * N).T
index_matrix_j = np.array([range(N)] * M)
arr = calculate_element(index_matrix_i, index_matrix_j, 600, 2000, 300, 500)
# print(arr)
plt.figure(figsize=(8, 5))
plt.imshow(arr, aspect='auto', origin='lower')
plt.show()
您可以使用np.indices()
生成所需的 output:
例如,
np.indices((3, 4))
輸出:
[[[0 0 0 0]
[1 1 1 1]
[2 2 2 2]]
[[0 1 2 3]
[0 1 2 3]
[0 1 2 3]]]
矢量化numpy比我的 2 核機器上的普通 jitted numba更快
import numpy as np
import matplotlib.pyplot as plt
M, N = 1200, 4000
i = np.arange(M)
j = np.arange(N)
i_mid, j_mid, i_sig, j_sig = 600, 2000, 300, 500
arr = np.exp(-(i - i_mid)**2 / (2 * i_sig**2))[:,None] * np.exp(-(j - j_mid)**2 / (2 * j_sig**2))
# %timeit 100 loops, best of 5: 8.82 ms per loop
plt.figure(figsize=(8, 5))
plt.imshow(arr, aspect='auto', origin='lower')
plt.show()
並聯並行numba
import numba as nb # tested with numba 0.55.1
@nb.njit(parallel=True)
def calculate_element_nb(i, j, i_mid, j_mid, i_sig, j_sig):
res = np.empty((i,j), np.float32)
for i in nb.prange(res.shape[0]):
for j in range(res.shape[1]):
res[i,j] = np.exp(-(i - i_mid)**2 / (2 * i_sig**2)) * np.exp(-(j - j_mid)**2 / (2 * j_sig**2))
return res
M, N = 1200, 4000
calculate_element_nb(M, N, 600, 2000, 300, 500)
# %timeit 10 loops, best of 5: 80.4 ms per loop
plt.figure(figsize=(8, 5))
plt.imshow(calculate_element_nb(M, N, 600, 2000, 300, 500), aspect='auto', origin='lower')
plt.show()
您可以使用單個循環來填充多維列表,在完成所有元素后,它將像這樣轉換為np.array
:
import numpy as np
m, n = 5, 5
arr = []
for i in range(0, m*n, n):
arr.append(list(range(i, i+n)))
print(np.array(arr))
Output:
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.