So I am trying to select some columns from a 3D matrix based on the values in a vector using NumPy. I have already solved the problem using a list comprehension, but I figured there might be a better way using NumPy's built-in methods. Does anyone know if such a method or combination of methods exists?
import numpy as np

matrix1 = np.array([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])
matrix2 = np.array([[10, 11, 12],
                    [13, 14, 15],
                    [16, 17, 18]])
total_matrix = np.array([matrix1, matrix2])  # shape (2, 3, 3)
vector = [0, 1, 1]
# Retrieve the first column from the first matrix, and the second and third columns from the second matrix.
result = np.array([total_matrix[index2, :, index1] for index1, index2 in enumerate(vector)]).transpose()
# result:
np.array([[1, 11, 12],
          [4, 14, 15],
          [7, 17, 18]])
In [58]: total_matrix[vector, np.arange(3)[:,None], np.arange(3)]
Out[58]:
array([[ 1, 11, 12],
       [ 4, 14, 15],
       [ 7, 17, 18]])
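vector indexes the first dimension. The other two index arrays, np.arange(3)[:,None] with shape (3,1) and np.arange(3) with shape (3,), broadcast with it to select the required (3,3) result, so that result[i, j] is total_matrix[vector[j], i, j]. While I knew the general principle, I tried a number of variations (about 9) before getting the right one.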
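A minimal sketch generalizing the broadcasting index above to a stack of any shape (k, n, m), assuming len(vector) == m and every entry of vector lies in range(k). The function name select_columns is hypothetical, not a NumPy API; it computes result[i, j] = stack[vector[j], i, j]:

```python
import numpy as np


def select_columns(stack, vector):
    """Hypothetical helper: R[i, j] = stack[vector[j], i, j].

    stack has shape (k, n, m); vector holds m integers in range(k).
    """
    stack = np.asarray(stack)
    vector = np.asarray(vector)
    _, n, m = stack.shape
    if vector.shape != (m,):
        raise ValueError("vector must have one entry per column")
    rows = np.arange(n)[:, None]  # shape (n, 1): selects the row
    cols = np.arange(m)           # shape (m,): selects the column
    # vector (m,), rows (n, 1) and cols (m,) broadcast together to (n, m)
    return stack[vector, rows, cols]


matrix1 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
matrix2 = np.array([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
total_matrix = np.array([matrix1, matrix2])
vector = [0, 1, 1]

result = select_columns(total_matrix, vector)
print(result)
# [[ 1 11 12]
#  [ 4 14 15]
#  [ 7 17 18]]
```

Unlike the list comprehension, this builds the whole (n, m) result in a single advanced-indexing step, with no Python-level loop over columns.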
The use of diagonal in the other answer is equivalent to indexing the first and last axes of total_matrix[vector] with the same arange and then transposing:
In [61]: total_matrix[vector][np.arange(3), :, np.arange(3)].T
Out[61]:
array([[ 1, 11, 12],
       [ 4, 14, 15],
       [ 7, 17, 18]])
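To see how np.diagonal relates to plain advanced indexing, here is a small comparison, assuming the question's total_matrix and vector. The diagonal over axes 0 and 2 of total_matrix[vector] matches indexing axes 0 and 2 with the same arange and transposing, while indexing the last two axes with the same arange picks each matrix's main diagonal, which is np.diagonal(..., axis1=1, axis2=2):

```python
import numpy as np

total_matrix = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                         [[10, 11, 12], [13, 14, 15], [16, 17, 18]]])
vector = [0, 1, 1]
stacked = total_matrix[vector]  # advanced indexing copies: shape (3, 3, 3)
ar = np.arange(3)

# np.diagonal removes axis1 and axis2 and appends the diagonal as the last
# axis, so diag_02[i, d] == stacked[d, i, d].
diag_02 = np.diagonal(stacked, axis1=0, axis2=2)
# Advanced indices split by a slice put the broadcast axis first:
# stacked[ar, :, ar][d, i] == stacked[d, i, d], so transpose to match.
index_02 = stacked[ar, :, ar].T

# Same arange on the last two axes: the main diagonal of each matrix.
diag_12 = np.diagonal(stacked, axis1=1, axis2=2)
index_12 = stacked[:, ar, ar]

print(np.array_equal(diag_02, index_02))  # True
print(np.array_equal(diag_12, index_12))  # True
```

Both diagonal routes first build the (3, 3, 3) copy total_matrix[vector]; the direct three-array index in the first answer avoids that intermediate.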
You can index total_matrix with your vector and then select the appropriate diagonal elements:
>>> np.diagonal(total_matrix[vector], axis1=0, axis2=2)
array([[ 1, 11, 12],
       [ 4, 14, 15],
       [ 7, 17, 18]])
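Two other built-ins, not used in either answer, also fit this pattern. This is a sketch assuming the question's arrays. np.choose picks, at each position, the element from the choice array named by the index (here vector broadcasts across the rows), and np.take_along_axis gathers along axis 0 using an index array with the same number of dimensions as total_matrix:

```python
import numpy as np

total_matrix = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                         [[10, 11, 12], [13, 14, 15], [16, 17, 18]]])
vector = [0, 1, 1]

# np.choose: out[i, j] = choices[a[i, j]][i, j]; vector (shape (3,))
# broadcasts against the (3, 3) choices, so a[i, j] == vector[j].
chosen = np.choose(vector, list(total_matrix))

# np.take_along_axis: indices need the same ndim as total_matrix; shape
# (1, 1, 3) broadcasts against the other two axes, giving shape (1, 3, 3).
idx = np.asarray(vector)[None, None, :]
taken = np.take_along_axis(total_matrix, idx, axis=0)[0]

print(chosen)
print(taken)
```

np.choose only accepts a limited number of choice arrays, so np.take_along_axis is the safer of the two when total_matrix has many layers.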