![](/img/trans.png)
[英]How to get a cartesian-product of all pair from two vectors in numpy?
[英]Numpy: create a matrix from a cartesian product of two vectors (or one with itself) while applying a function to all pairs
為了說明一下,我想創建一個協方差矩陣,其中每個元素都由內核函數k(x, y)
,並且我想對單個向量執行此操作。 應該是這樣的:
# This is given
x = [x1, x2, x3, x4, ...]
# This is what I want to compute
result = [[k(x1, x1), k(x1, x2), k(x1, x3), ...],
[k(x2, x1), k(x2, x2), ...],
[k(x3, x1), k(x3, x2), ...],
...]
但由於性能,當然應該在numpy數組中完成,理想情況下,不進行Python交互。 如果我不關心性能,我可能會寫:
result = np.zeros((len(x), len(x)))
for i in range(len(x)):
for j in range(len(x)):
result[i, j] = k(x[i], x[j])
但是我覺得必須有一種更慣用的方式來編寫這種模式。
如果k
在2D數組上運算,則可以使用np.meshgrid
。 但是,這會產生額外的內存開銷。 一種替代方法是創建與np.meshgrid
相同的2D
網格視圖,如下所示-
def meshgrid1D_view(x):
shp = (len(x),len(x))
mesh1 = np.broadcast_to(x,shp)
mesh2 = np.broadcast_to(x[:,None],shp)
return mesh1, mesh2
樣品運行-
In [140]: x
Out[140]: array([3, 5, 6, 8])
In [141]: np.meshgrid(x,x)
Out[141]:
[array([[3, 5, 6, 8],
[3, 5, 6, 8],
[3, 5, 6, 8],
[3, 5, 6, 8]]), array([[3, 3, 3, 3],
[5, 5, 5, 5],
[6, 6, 6, 6],
[8, 8, 8, 8]])]
In [142]: meshgrid1D(x)
Out[142]:
(array([[3, 5, 6, 8],
[3, 5, 6, 8],
[3, 5, 6, 8],
[3, 5, 6, 8]]), array([[3, 3, 3, 3],
[5, 5, 5, 5],
[6, 6, 6, 6],
[8, 8, 8, 8]]))
這有什么幫助?
它有助於提高內存效率,從而提高性能。 讓我們在大型陣列上進行測試以了解差異-
In [143]: x = np.random.randint(0,10,(10000))
In [144]: %timeit np.meshgrid(x,x)
10 loops, best of 3: 171 ms per loop
In [145]: %timeit meshgrid1D(x)
100000 loops, best of 3: 6.91 µs per loop
另一個解決方案是讓numpy自己進行廣播:
import numpy as np
def k(x,y):
return x**2+y
def meshgrid1D_view(x):
shp = (len(x),len(x))
mesh1 = np.broadcast_to(x,shp)
mesh2 = np.broadcast_to(x[:,None],shp)
return mesh1, mesh2
x = np.random.randint(0,10,(10000))
b=k(a[:,None],a[None,:])
def sol0(x):
k(x[:,None],x[None,:])
def sol1(x):
x,y=np.meshgrid(x,x)
k(x,y)
def sol2(x):
x,y=meshgrid1D_view(x)
k(x,y)
%timeit sol0(x)
165 ms ± 1.67 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit sol1(x)
655 ms ± 6.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit sol2(x)
341 ms ± 2.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
您會看到這更有效,並且代碼更少。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.