[英]Vectorized creation of an array of diagonal square arrays from a liner array in Numpy or Tensorflow
我有一個形狀為[batch_size, N]
的數組,例如:
[[1 2]
[3 4]
[5 6]]
我需要創建一個形狀為[batch_size, N, N]
的 3 個索引數組,其中對於每個batch
我都有一個N x N
對角矩陣,其中對角線由相應的batch
元素獲取,例如在這種情況下,在這個簡單的案例,我正在尋找的結果是:
[
[[1,0],[0,2]],
[[3,0],[0,4]],
[[5,0],[0,6]],
]
如何在沒有 for 循環和探索矢量化的情況下進行此操作? 我猜這是維度的擴展,但我找不到正確的函數來做到這一點。 (我需要它,因為我正在使用 tensorflow 和 numpy 進行原型設計)。
在tensorflow中嘗試一下:
import tensorflow as tf
A = [[1,2],[3 ,4],[5,6]]
B = tf.matrix_diag(A)
print(B.eval(session=tf.Session()))
[[[1 0]
[0 2]]
[[3 0]
[0 4]]
[[5 0]
[0 6]]]
方法1
這里有一個量化的一個與np.einsum
輸入陣列, a
-
# Initialize o/p array
out = np.zeros(a.shape + (a.shape[1],),dtype=a.dtype)
# Get diagonal view and assign into it input array values
diag = np.einsum('ijj->ij',out)
diag[:] = a
方法#2
另一個基於切片的分配-
m,n = a.shape
out = np.zeros((m,n,n),dtype=a.dtype)
out.reshape(-1,n**2)[...,::n+1] = a
可以使用numpy.diag
m = [[1, 2],
[3, 4],
[5, 6]]
[np.diag(b) for b in m]
編輯下圖顯示了上述解決方案的平均執行時間(實線),並將其與@Divakar的解決方案(虛線)進行了比較,以了解不同的批處理大小和不同的矩陣大小
我不相信您會獲得太多改進,但這只是基於此簡單指標
將np.expand_dims
與np.expand_dims
的按元素乘積一起np.eye
a = np.array([[1, 2],
[3, 4],
[5, 6]])
N = a.shape[1]
a = np.expand_dims(a, axis=1)
a*np.eye(N)
array([[[1., 0.],
[0., 2.]],
[[3., 0.],
[0., 4.]],
[[5., 0.],
[0., 6.]]])
np.expand_dims(a, axis=1)
向a
添加一個新軸,該軸現在將是(3, 1, 2)
3,1,2 (3, 1, 2)
ndarray:
array([[[1, 2]],
[[3, 4]],
[[5, 6]]])
現在,您可以將此數組乘以大小為N
單位矩陣,可以使用np.eye
生成該矩陣:
np.eye(N)
array([[1., 0.],
[0., 1.]])
這將產生所需的輸出:
a*np.eye(N)
array([[[1., 0.],
[0., 2.]],
[[3., 0.],
[0., 4.]],
[[5., 0.],
[0., 6.]]])
您基本上想要一個與np.block(..)
我需要同樣的東西,所以我寫了這個小函數:
def split_blocks(x, m=2, n=2):
"""
Reverse the action of np.block(..)
>>> x = np.random.uniform(-1, 1, (2, 18, 20))
>>> assert (np.block(split_blocks(x, 3, 4)) == x).all()
:param x: (.., M, N) input matrix to split into blocks
:param m: number of row splits
:param n: number of column, splits
:return:
"""
x = np.array(x, copy=False)
nd = x.ndim
*shape, nr, nc = x.shape
return list(map(list, x.reshape((*shape, m, nr//m, n, nc//n)).transpose(nd-2, nd, *range(nd-2), nd-1, nd+1)))
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.