簡體   English   中英

從 Numpy 或 Tensorflow 中的線性數組矢量化創建對角方形數組的數組

[英]Vectorized creation of an array of diagonal square arrays from a liner array in Numpy or Tensorflow

我有一個形狀為[batch_size, N]的數組,例如:

[[1  2]
 [3  4]
 [5  6]]

我需要創建一個形狀為[batch_size, N, N]的 3 個索引數組,其中對於每個batch我都有一個N x N對角矩陣,其中對角線由相應的batch元素獲取,例如在這種情況下,在這個簡單的案例,我正在尋找的結果是:

[
  [[1,0],[0,2]],
  [[3,0],[0,4]],
  [[5,0],[0,6]],
]

如何在沒有 for 循環和探索矢量化的情況下進行此操作? 我猜這是維度的擴展,但我找不到正確的函數來做到這一點。 (我需要它,因為我正在使用 tensorflow 和 numpy 進行原型設計)。

在tensorflow中嘗試一下:

import tensorflow as tf
A = [[1,2],[3 ,4],[5,6]]
B = tf.matrix_diag(A)
print(B.eval(session=tf.Session()))
[[[1 0]
  [0 2]]

 [[3 0]
  [0 4]]

 [[5 0]
  [0 6]]]

方法1

這里有一個量化的一個與np.einsum輸入陣列, a -

# Initialize o/p array
out = np.zeros(a.shape + (a.shape[1],),dtype=a.dtype)

# Get diagonal view and assign into it input array values
diag = np.einsum('ijj->ij',out)
diag[:] = a

方法#2

另一個基於切片的分配-

m,n = a.shape
out = np.zeros((m,n,n),dtype=a.dtype)
out.reshape(-1,n**2)[...,::n+1] = a

可以使用numpy.diag

m = [[1, 2],
 [3, 4],
 [5, 6]]

[np.diag(b) for b in m]

編輯下圖顯示了上述解決方案的平均執行時間(實線),並將其與@Divakar的解決方案(虛線)進行了比較,以了解不同的批處理大小和不同的矩陣大小

在此處輸入圖片說明

我不相信您會獲得太多改進,但這只是基於此簡單指標

np.expand_dimsnp.expand_dims的按元素乘積一起np.eye

a = np.array([[1,  2],
              [3,  4],
              [5, 6]])
N = a.shape[1]
a = np.expand_dims(a, axis=1)
a*np.eye(N)

array([[[1., 0.],
       [0., 2.]],

      [[3., 0.],
       [0., 4.]],

      [[5., 0.],
       [0., 6.]]])

說明

np.expand_dims(a, axis=1)a添加一個新軸,該軸現在將是(3, 1, 2) 3,1,2 (3, 1, 2) ndarray:

array([[[1, 2]],

       [[3, 4]],

       [[5, 6]]])

現在,您可以將此數組乘以大小為N單位矩陣,可以使用np.eye生成該矩陣:

np.eye(N)
array([[1., 0.],
       [0., 1.]])

這將產生所需的輸出:

a*np.eye(N)

array([[[1., 0.],
        [0., 2.]],

       [[3., 0.],
        [0., 4.]],

       [[5., 0.],
        [0., 6.]]])

您基本上想要一個與np.block(..)

我需要同樣的東西,所以我寫了這個小函數:

def split_blocks(x, m=2, n=2):
    """
    Reverse the action of np.block(..)

    >>> x = np.random.uniform(-1, 1, (2, 18, 20))
    >>> assert (np.block(split_blocks(x, 3, 4)) == x).all()

    :param x: (.., M, N) input matrix to split into blocks
    :param m: number of row splits
    :param n: number of column, splits
    :return:
    """
    x = np.array(x, copy=False)
    nd = x.ndim
    *shape, nr, nc = x.shape
    return list(map(list, x.reshape((*shape, m, nr//m, n, nc//n)).transpose(nd-2, nd, *range(nd-2), nd-1, nd+1)))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM