[英]Unexpected result with numpy.dot
我有兩個矩陣:
>>> a.shape
(100, 3, 1)
>>> b.shape
(100, 3, 3)
我想執行一個點積,這樣我的最終結果是 (100, 3, 1)。 但是,目前我收到:
>>> c = np.dot(b, a)
>>> c.shape
(100, 3, 100, 1)
有人可以解釋發生了什么嗎? 我正在閱讀文檔,但無法弄清楚。
編輯:
所以根據文檔(忽略它):
如果 a 和 b 都是二維數組,則是矩陣乘法,但首選使用 matmul 或 a @ b。
這給出了想要的結果,但我仍然很好奇,這里發生了什么? 應用np.dot
函數的什么規則來產生(100, 3, 100, 1)
?
這就是 dot 在您的情況下的工作方式:
dot(b, a)[i,j,k,m] = sum(b[i,j,:] * a[k,:,m])
您的輸出形狀正是文檔指定的方式:
(b.shape[0], b.shape[1], a.shape[0], a.shape[2])
如果這不是您所期望的,您可能正在尋找另一個矩陣乘法。
dot
將返回存儲在數組最后兩個維度中的矩陣的所有可能乘積。 使用matmul
aka @
運算符來廣播前導維度而不是組合它們:
np.matmul(b, a)
或者
b @ a
sum-products 的einsum
是einsum
,所以你也可以使用它:
np.einsum('aij,ajk->aik', b, a)
或者
np.einsum('ajk,aij->aik', a, b)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.