
Generating indices of a 2D NumPy array

I want to generate a 2D NumPy array whose elements are calculated from their positions, with something like the following code:

import numpy as np

def calculate_element(i, j, other_parameters):
    # compute the value of the element at row i, column j
    value_at_i_j = ...
    return value_at_i_j

def main():
    arr = np.zeros((M, N))  # (M, N) is the shape of the array
    for i in range(M):
        for j in range(N):
            # one Python-level function call per element
            arr[i, j] = calculate_element(i, j, ...)

This code runs extremely slowly because Python-level loops are not very efficient. Is there any way to make it faster?

For now, I use a workaround that calculates two 2D "index matrices", something like this:

def main():
    index_matrix_i = np.array([range(M)] * N).T
    index_matrix_j = np.array([range(N)] * M)

    '''
    index_matrix_i is like
    [[0,0,0,...],
     [1,1,1,...],
     [2,2,2,...],
     ...
    ]

    index_matrix_j is like
    [[0,1,2,...],
     [0,1,2,...],
     [0,1,2,...],
     ...
    ]
    '''

    arr = calculate_element(index_matrix_i, index_matrix_j, ...)
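The "index matrices" workaround is the same idea that NumPy's np.fromfunction packages up: it builds the coordinate arrays for the given shape and calls the function once with them, forwarding extra keyword arguments. A minimal sketch, assuming the Gaussian calculate_element from Edit 2 (reproduced so the block is self-contained) and a small shape chosen only for illustration. Note that np.fromfunction builds full-size index arrays internally, so it saves typing but not memory.

```python
import numpy as np

# Gaussian example from Edit 2 of the question; works elementwise on arrays
def calculate_element(i, j, i_mid, j_mid, i_sig, j_sig):
    gaus_i = np.exp(-((i - i_mid)**2) / (2 * i_sig**2))
    gaus_j = np.exp(-((j - j_mid)**2) / (2 * j_sig**2))
    return gaus_i * gaus_j

M, N = 6, 8  # small shape, for illustration only

# np.fromfunction calls calculate_element once with a float row-index array
# and a float column-index array, both of shape (M, N); keyword arguments are forwarded
arr = np.fromfunction(calculate_element, (M, N),
                      i_mid=3, j_mid=4, i_sig=1.5, j_sig=2.0)

# Same result as the explicit index-matrix workaround
index_matrix_i = np.array([range(M)] * N).T
index_matrix_j = np.array([range(N)] * M)
ref = calculate_element(index_matrix_i, index_matrix_j, 3, 4, 1.5, 2.0)

print(arr.shape, np.allclose(arr, ref))  # prints: (6, 8) True
```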

Edit 1: The code becomes much faster with the "index matrices" trick, so my main question is whether there is a way to avoid this trick, since it takes more memory. In short, I want a solution that is efficient in both time and space.
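To put the memory cost in numbers: the two full index matrices hold M*N integers each, while the same indices in broadcastable form (a column of M values and a row of N values) hold only M + N in total. A minimal sketch using np.ogrid, an alternative not mentioned in the question, with the sizes from Edit 2; the byte counts depend on the platform's default integer size, so they are computed rather than hard-coded.

```python
import numpy as np

M, N = 1200, 4000  # sizes from Edit 2

# Full index matrices, as in the workaround: two (M, N) arrays
index_matrix_i = np.array([range(M)] * N).T
index_matrix_j = np.array([range(N)] * M)
dense_bytes = index_matrix_i.nbytes + index_matrix_j.nbytes

# Broadcastable indices with shapes (M, 1) and (1, N); arithmetic between them
# broadcasts to (M, N) without materializing the full index grids
i, j = np.ogrid[0:M, 0:N]
sparse_bytes = i.nbytes + j.nbytes

print(i.shape, j.shape)  # prints: (1200, 1) (1, 4000)
print(dense_bytes, sparse_bytes)
```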

Edit 2: Some examples I tested:

import numpy as np
import matplotlib.pyplot as plt

# a simple 2D Gaussian
def calculate_element(i, j, i_mid, j_mid, i_sig, j_sig):
    gaus_i = np.exp(-((i - i_mid)**2) / (2 * i_sig**2))
    gaus_j = np.exp(-((j - j_mid)**2) / (2 * j_sig**2))
    return gaus_i * gaus_j

# size of M, N
M, N = 1200, 4000

# use for loops to go through every element
# this code takes ~10 seconds
def main_1():
    arr = np.zeros((M, N))  # (M, N) is the shape of the array
    for i in range(M):
        for j in range(N):
            arr[i, j] = calculate_element(i, j, 600, 2000, 300, 500)
    # print(arr)
    plt.figure(figsize=(8, 5))
    plt.imshow(arr, aspect='auto', origin='lower')
    plt.show()

# use index matrices
# this code takes <1 second
def main_2():
    index_matrix_i = np.array([range(M)] * N).T
    index_matrix_j = np.array([range(N)] * M)
    arr = calculate_element(index_matrix_i, index_matrix_j, 600, 2000, 300, 500)

    # print(arr)
    plt.figure(figsize=(8, 5))
    plt.imshow(arr, aspect='auto', origin='lower')
    plt.show()
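The index-matrix version gives the same array as the loop because every operation in calculate_element (subtraction, **, division, np.exp, multiplication) is elementwise, so passing whole arrays computes all elements in one pass. A minimal check of that equivalence on a small grid, with plotting left out; the sizes and parameters are scaled down from Edit 2 purely for illustration.

```python
import numpy as np

def calculate_element(i, j, i_mid, j_mid, i_sig, j_sig):
    gaus_i = np.exp(-((i - i_mid)**2) / (2 * i_sig**2))
    gaus_j = np.exp(-((j - j_mid)**2) / (2 * j_sig**2))
    return gaus_i * gaus_j

M, N = 12, 40           # small sizes so the loop version finishes instantly
params = (6, 20, 3, 5)  # i_mid, j_mid, i_sig, j_sig (illustrative values)

# Loop version (main_1 without plotting)
arr_loop = np.zeros((M, N))
for i in range(M):
    for j in range(N):
        arr_loop[i, j] = calculate_element(i, j, *params)

# Index-matrix version (main_2 without plotting)
index_matrix_i = np.array([range(M)] * N).T
index_matrix_j = np.array([range(N)] * M)
arr_vec = calculate_element(index_matrix_i, index_matrix_j, *params)

print(np.allclose(arr_loop, arr_vec))  # prints: True
```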

You can use np.indices() to generate the index matrices directly. For example,

np.indices((3, 4))

outputs:

[[[0 0 0 0]
  [1 1 1 1]
  [2 2 2 2]]

 [[0 1 2 3]
  [0 1 2 3]
  [0 1 2 3]]]
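np.indices((M, N)) returns the same two (M, N) grids as the question's index matrices, so it replaces the manual construction but uses the same memory. Since NumPy 1.17 it also accepts sparse=True, which returns an (M, 1) row-index array and a (1, N) column-index array that broadcast against each other. A minimal sketch, reusing the Gaussian from Edit 2 with a small shape and parameters chosen only for illustration. For this separable function, the sparse version also evaluates np.exp on only M + N values.

```python
import numpy as np

def calculate_element(i, j, i_mid, j_mid, i_sig, j_sig):
    gaus_i = np.exp(-((i - i_mid)**2) / (2 * i_sig**2))
    gaus_j = np.exp(-((j - j_mid)**2) / (2 * j_sig**2))
    return gaus_i * gaus_j

M, N = 3, 4
params = (1, 2, 1.0, 2.0)  # i_mid, j_mid, i_sig, j_sig (illustrative values)

# Dense: one (2, M, N) array, unpacked into two (M, N) index grids
ii, jj = np.indices((M, N))
dense = calculate_element(ii, jj, *params)

# Sparse: shapes (M, 1) and (1, N); broadcasting yields an (M, N) result
i_sp, j_sp = np.indices((M, N), sparse=True)
sparse = calculate_element(i_sp, j_sp, *params)

print(i_sp.shape, j_sp.shape, sparse.shape)  # prints: (3, 1) (1, 4) (3, 4)
print(np.allclose(dense, sparse))            # prints: True
```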

A vectorized NumPy version is faster than a trivially jitted Numba version on my 2-core machine.

import numpy as np
import matplotlib.pyplot as plt

M, N = 1200, 4000
i = np.arange(M)
j = np.arange(N)
i_mid, j_mid, i_sig, j_sig = 600, 2000, 300, 500

arr = np.exp(-(i - i_mid)**2 / (2 * i_sig**2))[:,None] * np.exp(-(j - j_mid)**2 / (2 * j_sig**2))
# %timeit 100 loops, best of 5: 8.82 ms per loop

plt.figure(figsize=(8, 5))
plt.imshow(arr, aspect='auto', origin='lower')
plt.show()
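The broadcasting above works because the 2D Gaussian is separable: it is the product of a function of i alone and a function of j alone, so the (M, N) array is the outer product of two 1D vectors, and np.exp runs on only M + N values. A minimal sketch showing that np.outer gives the same array as the [:, None] broadcasting:

```python
import numpy as np

M, N = 1200, 4000
i_mid, j_mid, i_sig, j_sig = 600, 2000, 300, 500

# 1D factors: one value per row and one per column
g_i = np.exp(-(np.arange(M) - i_mid)**2 / (2 * i_sig**2))  # shape (M,)
g_j = np.exp(-(np.arange(N) - j_mid)**2 / (2 * j_sig**2))  # shape (N,)

# [:, None] turns g_i into an (M, 1) column; multiplying by the (N,) row broadcasts to (M, N)
arr_broadcast = g_i[:, None] * g_j
arr_outer = np.outer(g_i, g_j)

print(arr_outer.shape, np.allclose(arr_broadcast, arr_outer))  # prints: (1200, 4000) True
```

This shortcut only applies to separable functions; for a general f(i, j), broadcastable index arrays (for example from np.indices(..., sparse=True)) can be passed to the function instead.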

[Plot: NumPy result]

Jitted parallel Numba version:

import numpy as np
import numba as nb  # tested with numba 0.55.1
import matplotlib.pyplot as plt

@nb.njit(parallel=True)
def calculate_element_nb(M, N, i_mid, j_mid, i_sig, j_sig):
    # float32 output uses half the memory of float64
    res = np.empty((M, N), np.float32)
    # prange splits the rows across threads; each iteration fills one row
    for i in nb.prange(M):
        for j in range(N):
            res[i, j] = np.exp(-(i - i_mid)**2 / (2 * i_sig**2)) * np.exp(-(j - j_mid)**2 / (2 * j_sig**2))
    return res

M, N = 1200, 4000

# the first call also compiles the function
calculate_element_nb(M, N, 600, 2000, 300, 500)
# %timeit 10 loops, best of 5: 80.4 ms per loop

plt.figure(figsize=(8, 5))
plt.imshow(calculate_element_nb(M, N, 600, 2000, 300, 500), aspect='auto', origin='lower')
plt.show()

[Plot: Numba result]

You can use a single loop to fill a nested list and then convert it to a NumPy array once all elements are filled, like this:

import numpy as np

m, n = 5, 5
arr = []
for i in range(0, m*n, n):
    arr.append(list(range(i, i+n)))
print(np.array(arr))

Output:

[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]]
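For this particular fill pattern, where element (i, j) equals i*n + j, the Python loop can be replaced by vectorized calls. A minimal sketch (not part of the original answer) comparing the loop with np.arange plus reshape, and with the position-based form i*n + j computed from broadcastable np.ogrid indices:

```python
import numpy as np

m, n = 5, 5

# Loop version from the answer above
arr_loop = []
for i in range(0, m*n, n):
    arr_loop.append(list(range(i, i+n)))
arr_loop = np.array(arr_loop)

# Vectorized: the values 0..m*n-1 laid out row by row
arr_vec = np.arange(m * n).reshape(m, n)

# Position-based: (m, 1) row indices and (1, n) column indices broadcast to (m, n)
i, j = np.ogrid[0:m, 0:n]
arr_idx = i * n + j

print(np.array_equal(arr_loop, arr_vec), np.array_equal(arr_loop, arr_idx))  # prints: True True
```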
