[英]Construct (N+1)-dimensional diagonal matrix from values in N-dimensional array
我有一个N维数组。 我想通过将最终尺寸的值放在对角线上来将其扩展为(N + 1)维数组。
例如,使用显式循环:
In [197]: M = arange(5*3).reshape(5, 3)
In [198]: numpy.dstack([numpy.diag(M[i, :]) for i in range(M.shape[0])]).T
Out[198]:
array([[[ 0, 0, 0],
[ 0, 1, 0],
[ 0, 0, 2]],
[[ 3, 0, 0],
[ 0, 4, 0],
[ 0, 0, 5]],
[[ 6, 0, 0],
[ 0, 7, 0],
[ 0, 0, 8]],
[[ 9, 0, 0],
[ 0, 10, 0],
[ 0, 0, 11]],
[[12, 0, 0],
[ 0, 13, 0],
[ 0, 0, 14]]])
这是一个5×3×3阵列。
我的实际数组很大,我想避免显式循环(在map
隐藏循环而不是列表理解没有性能增益;它仍然是一个循环)。 尽管numpy.diag
适用于构造规则的二维对角矩阵,但它不会扩展到更高的维度(当给定二维数组时,它将提取其对角线)。 numpy.diagflat
返回的数组使一切成为一个大的对角线,产生一个15×15的阵列,它有更多的零,不能重新形成5×3×3。
有没有办法从N维数组中的值有效地构造(N + 1) - 对角矩阵,而不需要多次调用diag
?
使用numpy.diagonal
来查看正确形状的N + 1维数组的相关对角线,强制视图可以使用setflags
写入,并写入视图:
expanded = numpy.zeros(M.shape + M.shape[-1:], dtype=M.dtype)
diagonals = numpy.diagonal(expanded, axis1=-2, axis2=-1)
diagonals.setflags(write=True)
diagonals[:] = M
这会产生expanded
所需的数组。
你可以使用普遍存在的np.einsum
几乎不可能猜到的你不知道的功能。 当使用如下时, einsum
将返回广义对角线的可写视图:
>>> import numpy as np
>>> M = np.arange(5*3).reshape(5, 3)
>>>
>>> out = np.zeros((*M.shape, M.shape[-1]), M.dtype)
>>> np.einsum('...jj->...j', out)[...] = M
>>> out
array([[[ 0, 0, 0],
[ 0, 1, 0],
[ 0, 0, 2]],
[[ 3, 0, 0],
[ 0, 4, 0],
[ 0, 0, 5]],
[[ 6, 0, 0],
[ 0, 7, 0],
[ 0, 0, 8]],
[[ 9, 0, 0],
[ 0, 10, 0],
[ 0, 0, 11]],
[[12, 0, 0],
[ 0, 13, 0],
[ 0, 0, 14]]])
将ND数组的最后一个维度转换为对角矩阵的一般方法:
我们需要减少数组的维数,将numpy.diag()
函数应用于每个向量,然后将其重建为原始维度+ 1。
将矩阵重塑为二维:
M.reshape(-1, M.shape[-1])
然后使用map
将np.diag
应用于此,并使用以下内容重建具有附加维度的矩阵:
result.reshape([*M.shape, M.shape[-1]])
所有这些结合起来给出了以下内容:
result = np.array(list(map(
np.diag,
M.reshape(-1, M.shape[-1])
))).reshape([*M.shape, M.shape[-1]])
一个例子:
shape = np.arange(2,8)
M = np.arange(shape.prod()).reshape(shape)
print(M.shape) # (2, 3, 4, 5, 6, 7)
result = np.array(list(map(np.diag, M.reshape(-1, M.shape[-1])))).reshape([*M.shape, M.shape[-1]])
print(result.shape) # (2, 3, 4, 5, 6, 7, 7)
和res[0,0,0,0,2,:]
包含以下内容:
array([[14, 0, 0, 0, 0, 0, 0],
[ 0, 15, 0, 0, 0, 0, 0],
[ 0, 0, 16, 0, 0, 0, 0],
[ 0, 0, 0, 17, 0, 0, 0],
[ 0, 0, 0, 0, 18, 0, 0],
[ 0, 0, 0, 0, 0, 19, 0],
[ 0, 0, 0, 0, 0, 0, 20]])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.