[英]Generating indices of a 2D NumPy array
我想生成一个 2D numpy
数组,其中的元素是从它们的位置计算出来的。 类似于以下代码:
import numpy as np
def calculate_element(i, j, other_parameters):
# do something
return value_at_i_j
def main():
arr = np.zeros((M, N)) # (M, N) is the shape of the array
for i in range(M):
for j in range(N):
arr[i][j] = calculate_element(i, j, ...)
这段代码运行得非常慢,因为 Python 中的循环效率不高。 在这种情况下有没有办法更快地做到这一点?
顺便说一句,现在我通过计算两个二维“索引矩阵”来使用解决方法。 是这样的:
def main():
index_matrix_i = np.array([range(M)] * N).T
index_matrix_j = np.array([range(N)] * M)
'''
index_matrix_i is like
[[0,0,0,...],
[1,1,1,...],
[2,2,2,...],
...
]
index_matrix_j is like
[[0,1,2,...],
[0,1,2,...],
[0,1,2,...],
...
]
'''
arr = calculate_element(index_matrix_i, index_matrix_j, ...)
Edit1:应用“索引矩阵”技巧后代码变得更快,所以我想问的主要问题是是否有办法不使用这个技巧,因为它需要更多 memory。简而言之,我想有一个在时间和空间上都高效的解决方案。
Edit2:我测试的一些例子
# a simple 2D Gaussian
def calculate_element(i, j, i_mid, j_mid, i_sig, j_sig):
gaus_i = np.exp(-((i - i_mid)**2) / (2 * i_sig**2))
gaus_j = np.exp(-((j - j_mid)**2) / (2 * j_sig**2))
return gaus_i * gaus_j
# size of M, N
M, N = 1200, 4000
# use for loops to go through every element
# this code takes ~10 seconds
def main_1():
arr = np.zeros((M, N)) # (M, N) is the shape of the array
for i in range(M):
for j in range(N):
arr[i][j] = calculate_element(i, j, 600, 2000, 300, 500)
# print(arr)
plt.figure(figsize=(8, 5))
plt.imshow(arr, aspect='auto', origin='lower')
plt.show()
# use index matrices
# this code takes <1 second
def main_2():
index_matrix_i = np.array([range(M)] * N).T
index_matrix_j = np.array([range(N)] * M)
arr = calculate_element(index_matrix_i, index_matrix_j, 600, 2000, 300, 500)
# print(arr)
plt.figure(figsize=(8, 5))
plt.imshow(arr, aspect='auto', origin='lower')
plt.show()
您可以使用np.indices()
生成所需的 output:
例如,
np.indices((3, 4))
输出:
[[[0 0 0 0]
[1 1 1 1]
[2 2 2 2]]
[[0 1 2 3]
[0 1 2 3]
[0 1 2 3]]]
矢量化numpy比我的 2 核机器上的普通 jitted numba更快
import numpy as np
import matplotlib.pyplot as plt
M, N = 1200, 4000
i = np.arange(M)
j = np.arange(N)
i_mid, j_mid, i_sig, j_sig = 600, 2000, 300, 500
arr = np.exp(-(i - i_mid)**2 / (2 * i_sig**2))[:,None] * np.exp(-(j - j_mid)**2 / (2 * j_sig**2))
# %timeit 100 loops, best of 5: 8.82 ms per loop
plt.figure(figsize=(8, 5))
plt.imshow(arr, aspect='auto', origin='lower')
plt.show()
并联并行numba
import numba as nb # tested with numba 0.55.1
@nb.njit(parallel=True)
def calculate_element_nb(i, j, i_mid, j_mid, i_sig, j_sig):
res = np.empty((i,j), np.float32)
for i in nb.prange(res.shape[0]):
for j in range(res.shape[1]):
res[i,j] = np.exp(-(i - i_mid)**2 / (2 * i_sig**2)) * np.exp(-(j - j_mid)**2 / (2 * j_sig**2))
return res
M, N = 1200, 4000
calculate_element_nb(M, N, 600, 2000, 300, 500)
# %timeit 10 loops, best of 5: 80.4 ms per loop
plt.figure(figsize=(8, 5))
plt.imshow(calculate_element_nb(M, N, 600, 2000, 300, 500), aspect='auto', origin='lower')
plt.show()
您可以使用单个循环来填充多维列表,在完成所有元素后,它将像这样转换为np.array
:
import numpy as np
m, n = 5, 5
arr = []
for i in range(0, m*n, n):
arr.append(list(range(i, i+n)))
print(np.array(arr))
Output:
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.