Given a 2D NumPy array, I would like to build a diagonal matrix from each row as fast as possible. Right now I'm using a list comprehension, but I'm wondering whether it can be vectorised.
For example, using the following M array:
M = np.random.rand(5, 3)
[[ 0.25891593 0.07299478 0.36586996]
[ 0.30851087 0.37131459 0.16274825]
[ 0.71061831 0.67718718 0.09562581]
[ 0.71588836 0.76772047 0.15476079]
[ 0.92985142 0.22263399 0.88027331]]
I would like to compute the following array:
np.array([np.diag(row) for row in M])
array([[[ 0.25891593, 0. , 0. ],
[ 0. , 0.07299478, 0. ],
[ 0. , 0. , 0.36586996]],
[[ 0.30851087, 0. , 0. ],
[ 0. , 0.37131459, 0. ],
[ 0. , 0. , 0.16274825]],
[[ 0.71061831, 0. , 0. ],
[ 0. , 0.67718718, 0. ],
[ 0. , 0. , 0.09562581]],
[[ 0.71588836, 0. , 0. ],
[ 0. , 0.76772047, 0. ],
[ 0. , 0. , 0.15476079]],
[[ 0.92985142, 0. , 0. ],
[ 0. , 0.22263399, 0. ],
[ 0. , 0. , 0.88027331]]])
Here's one way, using element-wise multiplication of np.eye(3) (the 3x3 identity array) and a slightly re-shaped M:
>>> M = np.random.rand(5, 3)
>>> np.eye(3) * M[:,np.newaxis,:]
array([[[ 0.42527357, 0. , 0. ],
[ 0. , 0.17557419, 0. ],
[ 0. , 0. , 0.61920924]],
[[ 0.04991268, 0. , 0. ],
[ 0. , 0.74000307, 0. ],
[ 0. , 0. , 0.34541354]],
[[ 0.71464307, 0. , 0. ],
[ 0. , 0.11878955, 0. ],
[ 0. , 0. , 0.65411844]],
[[ 0.01699954, 0. , 0. ],
[ 0. , 0.39927673, 0. ],
[ 0. , 0. , 0.14378892]],
[[ 0.5209439 , 0. , 0. ],
[ 0. , 0.34520876, 0. ],
[ 0. , 0. , 0.53862677]]])
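(By "re-shaped M" I mean M[:,np.newaxis,:], which inserts a new length-1 axis in the middle and gives M the shape (5, 1, 3). Broadcasting that against the (3, 3) identity produces one 3x3 product per row of M, so the result has shape (5, 3, 3).)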
Building on the good answer by @ajcr, a much faster alternative uses fancy indexing (tested in NumPy 1.9.0):
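import numpy as np

def sol0(M):
    # Broadcasting: (k, k) identity times the (n, 1, k) view of M
    return np.eye(M.shape[1]) * M[:,np.newaxis,:]

def sol1(M):
    # Fancy indexing: allocate zeros, then write M directly onto the diagonals
    b = np.zeros((M.shape[0], M.shape[1], M.shape[1]))
    diag = np.arange(M.shape[1])
    b[:, diag, diag] = M  # sets b[i, j, j] = M[i, j] for every i, j
    return b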
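The two index arrays in b[:, diag, diag] are paired element by element, so the expression selects positions (i, 0, 0), (i, 1, 1), ... for every row i and has shape (n, k), the same shape as M. The sketch below repeats sol0 and sol1 so it runs standalone and checks both against the list comprehension from the question; the extra non-square input (4, 7) is an added check, not something tested in the thread.

```python
import numpy as np

def sol0(M):
    # Broadcasting approach
    return np.eye(M.shape[1]) * M[:, np.newaxis, :]

def sol1(M):
    # Fancy-indexing approach
    n, k = M.shape
    b = np.zeros((n, k, k))
    diag = np.arange(k)
    b[:, diag, diag] = M  # b[i, j, j] = M[i, j]
    return b

for shape in [(5, 3), (1000, 3), (4, 7)]:
    M = np.random.random(shape)
    reference = np.array([np.diag(row) for row in M])
    assert np.array_equal(sol0(M), reference)
    assert np.array_equal(sol1(M), reference)
print("all equal")
```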
Timing on a 1000 x 3 input shows this is roughly 4.7x faster (111 µs vs. 23.8 µs per loop):
M = np.random.random((1000, 3))
%timeit sol0(M)
#10000 loops, best of 3: 111 µs per loop
%timeit sol1(M)
#10000 loops, best of 3: 23.8 µs per loop
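%timeit is an IPython magic. Outside IPython, a comparable measurement can be made with the standard-library timeit module. The sketch below also adds a third variant, sol2, which is not from either answer: it writes into a flattened view of the output with a stride of k + 1, relying on the freshly allocated zeros array being C-contiguous so that reshape returns a view rather than a copy. No timings are claimed here; results will depend on the machine and NumPy version.

```python
import timeit
import numpy as np

def sol0(M):
    return np.eye(M.shape[1]) * M[:, np.newaxis, :]

def sol1(M):
    n, k = M.shape
    b = np.zeros((n, k, k))
    diag = np.arange(k)
    b[:, diag, diag] = M
    return b

def sol2(M):
    # Hypothetical variant (not from the answers above): each k x k block of a
    # C-contiguous (n, k, k) array flattens to k*k entries, and its diagonal
    # sits at flat positions 0, k+1, 2(k+1), ..., k*k - 1.
    n, k = M.shape
    b = np.zeros((n, k, k))
    b.reshape(n, k * k)[:, ::k + 1] = M  # reshape is a view, so this writes into b
    return b

M = np.random.random((1000, 3))
assert np.array_equal(sol1(M), sol2(M))

for f in (sol0, sol1, sol2):
    # Best of 3 repeats of 1000 calls each, reported per call
    t = min(timeit.repeat(lambda: f(M), number=1000, repeat=3)) / 1000
    print(f"{f.__name__}: {t * 1e6:.1f} µs per loop")
```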