[英]Numpy einsum outer sum of 2d-arrays
我試圖尋找答案,但找不到我需要的東西。 抱歉,如果這是一個重復的問題。
假設我有一個形狀為(n, n*m)
的二維數組。 我想做的是此數組對其轉置的外部求和,從而得到形狀為(n*m, n*m)
的數組。 例如,假設我有
A = array([[1., 1., 2., 2.],
[1., 1., 2., 2.]])
我想做一個A
和AT
總和,這樣輸出是:
>>> array([[2., 2., 3., 3.],
[2., 2., 3., 3.],
[3., 3., 4., 4.],
[3., 3., 4., 4.]])
請注意, np.add.outer
無效,因為它會將輸入中的向量散亂化。 我可以通過做類似的事情
np.tile(A, (2, 1)) + np.tile(A.T, (1, 2))
但是,當n
和m
較大時( n > 100
和m > 1000
),這似乎並不合理。 可以使用einsum
來寫這個總和嗎? 我只是無法弄清einsum
。
要利用broadcasting
,我們需要將其分解為3D
,然后置換軸並添加-
n = A.shape[0]
m = A.shape[1]//n
a = A.reshape(n,m,n) # reshape to 3D
out = (a[None,:,:,:] + a.transpose(1,2,0)[:,:,None,:]).reshape(n*m,-1)
樣品運行以進行驗證-
In [359]: # Setup input array
...: np.random.seed(0)
...: n,m = 3,4
...: A = np.random.randint(1,10,(n,n*m))
In [360]: # Original soln
...: out0 = np.tile(A, (m, 1)) + np.tile(A.T, (1, m))
In [361]: # Posted soln
...: n = A.shape[0]
...: m = A.shape[1]//n
...: a = A.reshape(n,m,n)
...: out = (a[None,:,:,:] + a.transpose(1,2,0)[:,:,None,:]).reshape(n*m,-1)
In [362]: np.allclose(out0, out)
Out[362]: True
n
較大的時間m
In [363]: # Setup input array
...: np.random.seed(0)
...: n,m = 100,100
...: A = np.random.randint(1,10,(n,n*m))
In [364]: %timeit np.tile(A, (m, 1)) + np.tile(A.T, (1, m))
1 loop, best of 3: 407 ms per loop
In [365]: %%timeit
...: # Posted soln
...: n = A.shape[0]
...: m = A.shape[1]//n
...: a = A.reshape(n,m,n)
...: out = (a[None,:,:,:] + a.transpose(1,2,0)[:,:,None,:]).reshape(n*m,-1)
1 loop, best of 3: 219 ms per loop
numexpr
進一步提高性能
我們可以利用具有numexpr
模塊multi-core
來處理大數據並獲得內存效率,從而提高性能-
import numexpr as ne
n = A.shape[0]
m = A.shape[1]//n
a = A.reshape(n,m,n)
p1 = a[None,:,:,:]
p2 = a.transpose(1,2,0)[:,:,None,:]
out = ne.evaluate('p1+p2').reshape(n*m,-1)
n
, m
設置相同的定時-
In [367]: %%timeit
...: # Posted soln
...: n = A.shape[0]
...: m = A.shape[1]//n
...: a = A.reshape(n,m,n)
...: p1 = a[None,:,:,:]
...: p2 = a.transpose(1,2,0)[:,:,None,:]
...: out = ne.evaluate('p1+p2').reshape(n*m,-1)
10 loops, best of 3: 152 ms per loop
一種方法是
(A.reshape(-1,*A.shape).T+A)[:,0,:]
我認為這將占用n>100
和m>1000
的大量內存。
但這不一樣
np.add.outer(A,A)[:,0,:].reshape(4,-1)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.