[英]Efficient way of constructing a 3D stack of block diagonal matrix in numpy/scipy from a 3D stack of matrices
我正在尝试从给定的矩阵堆栈(nXmXm)中以 nXMXM 的形式在 numpy/scipy 中构造一个块对角矩阵堆栈,其中 M=k*m,k 是矩阵堆栈的数量。 目前,我在 for 循环中使用 scipy.linalg.block_diag function 来执行此任务:
import numpy as np
import scipy.linalg as linalg
a = np.ones((5,2,2))
b = np.ones((5,2,2))
c = np.ones((5,2,2))
result = np.zeros((5,6,6))
for k in range(0,5):
result[k,:,:] = linalg.block_diag(a[k,:,:],b[k,:,:],c[k,:,:])
但是,由于在我的情况下 n 变得非常大,因此我正在寻找一种比 for 循环更有效的方法。 我发现3D numpy 数组成块对角矩阵,但这并不能真正解决我的问题。 我能想象的任何事情都是将每个矩阵堆栈转换为块对角线
import numpy as np
import scipy.linalg as linalg
a = np.ones((5,2,2))
b = np.ones((5,2,2))
c = np.ones((5,2,2))
a = linalg.block_diag(*a)
b = linalg.block_diag(*b)
c = linalg.block_diag(*c)
并通过重塑从中构造结果矩阵
result = linalg.block_diag(a,b,c)
result = result.reshape((5,6,6))
这不会重塑。 我什至不知道,如果这种方法会更有效,所以我问我是否走在正确的轨道上,或者是否有人知道构建这个块对角 3D 矩阵的更好方法,或者我是否必须坚持for 循环解决方案。
编辑:由于我是这个平台的新手,我不知道该把它留在哪里(编辑或回答?),但我想分享我的最终解决方案:panadestein 的 highlightet 解决方案非常好用且简单,但我m 现在使用更高维度的 arrays,我的矩阵位于最后两个维度。 此外,我的矩阵不再具有相同的维度(主要是 1x1、2x2、3x3 的混合),因此我采用了 V. Ayrat 的解决方案,并进行了微小的更改:
def nd_block_diag(arrs):
shapes = np.array([i.shape for i in arrs])
out = np.zeros(np.append(np.amax(shapes[:,:-2],axis=0), [shapes[:,-2].sum(), shapes[:,-1].sum()]))
r, c = 0, 0
for i, (rr, cc) in enumerate(shapes[:,-2:]):
out[..., r:r + rr, c:c + cc] = arrs[i]
r += rr
c += cc
return out
如果输入 arrays 的形状正确(即要广播的尺寸不会自动添加),它也适用于数组广播。 感谢 pandestein 和 V. Ayrat 的友好和快速的帮助,我学到了很多关于列表推导和数组索引/切片的可能性!
我不认为你可以逃避所有可能的循环来解决你的问题。 我发现一种比您的for
循环更方便且可能更有效的方法是使用列表推导:
import numpy as np
from scipy.linalg import block_diag
# Define input matrices
a = np.ones((5, 2, 2))
b = np.ones((5, 2, 2))
c = np.ones((5, 2, 2))
# Generate block diagonal matrices
mats = np.array([a, b, c]).reshape(5, 3, 2, 2)
result = [block_diag(*bmats) for bmats in mats]
也许这可以为您提供一些改进实施的想法。
block_diag也只是遍历形状。 几乎所有时间都花在复制数据上,因此您可以按照自己的方式进行操作,例如只需更改block_diag
的源代码
arrs = a, b, c
shapes = np.array([i.shape for i in arrs])
out = np.zeros([shapes[0, 0], shapes[:, 1].sum(), shapes[:, 2].sum()])
r, c = 0, 0
for i, (_, rr, cc) in enumerate(shapes):
out[:, r:r + rr, c:c + cc] = arrs[i]
r += rr
c += cc
print(np.allclose(result, out))
# True
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.