[英]How numpy.tensordot command works?and what is the meaning of summing over axis in this command?
[英]How does numpy.tensordot function works step-by-step?
我是 numpy 的新手,所以我在可視化numpy.tensordot()
函數的工作時numpy.tensordot()
一些問題。 根據tensordot
的文檔,軸在參數中傳遞,其中軸=0 或 1 表示正常矩陣乘法,而軸=2 表示收縮。
有人可以解釋一下乘法將如何處理給定的例子嗎?
示例 1:
a=[1,1] b=[2,2] for axes=0,1
為什么它會在 axes=2 時引發錯誤?
示例 2:a=[[1,1],[1,1]] b=[[2,2],[2,2]] for axes=0,1,2
編輯:此答案的最初重點是axes
是元組的情況,為每個參數指定一個或多個軸。 這種用法允許我們對傳統的dot
進行變體,特別是對於大於 2d 的數組(我在鏈接問題中的回答也是, https://stackoverflow.com/a/41870980/901925 )。 作為標量的軸是一種特殊情況,它被翻譯成元組版本。 所以它的核心仍然是一個dot
積。
In [235]: a=[1,1]; b=[2,2]
a
和b
是列表; tensordot
將它們變成數組。
In [236]: np.tensordot(a,b,(0,0))
Out[236]: array(4)
由於它們都是一維數組,我們將軸值指定為 0。
如果我們嘗試指定 1:
In [237]: np.tensordot(a,b,(0,1))
---------------------------------------------------------------------------
1282 else:
1283 for k in range(na):
-> 1284 if as_[axes_a[k]] != bs[axes_b[k]]:
1285 equal = False
1286 break
IndexError: tuple index out of range
它正在檢查a
的軸 0 的大小是否與b
的軸 1 的大小匹配。 但由於b
是 1d,它無法檢查。
In [239]: np.array(a).shape[0]
Out[239]: 2
In [240]: np.array(b).shape[1]
IndexError: tuple index out of range
你的第二個例子是二維數組:
In [242]: a=np.array([[1,1],[1,1]]); b=np.array([[2,2],[2,2]])
指定的最后一個軸a
和第一b
(第二到最后),產生傳統的矩陣(點)產物:
In [243]: np.tensordot(a,b,(1,0))
Out[243]:
array([[4, 4],
[4, 4]])
In [244]: a.dot(b)
Out[244]:
array([[4, 4],
[4, 4]])
更好的診斷值:
In [250]: a=np.array([[1,2],[3,4]]); b=np.array([[2,3],[2,1]])
In [251]: np.tensordot(a,b,(1,0))
Out[251]:
array([[ 6, 5],
[14, 13]])
In [252]: np.dot(a,b)
Out[252]:
array([[ 6, 5],
[14, 13]])
In [253]: np.tensordot(a,b,(0,1))
Out[253]:
array([[11, 5],
[16, 8]])
In [254]: np.dot(b,a) # same numbers, different layout
Out[254]:
array([[11, 16],
[ 5, 8]])
In [255]: np.dot(b,a).T
Out[255]:
array([[11, 5],
[16, 8]])
另一個配對:
In [256]: np.tensordot(a,b,(0,0))
In [257]: np.dot(a.T,b)
(0,1,2) 軸是完全錯誤的。 軸參數應該是 2 個數字或 2 個元組,對應於 2 個參數。
tensordot
的基本處理是對輸入進行轉置和重塑,以便它可以將結果傳遞給np.dot
用於常規(a 的最后一個,b 的最后一個)矩陣乘積。
如果我對tensordot
代碼的閱讀是正確的,則axes
參數將轉換為兩個列表:
def foo(axes):
try:
iter(axes)
except Exception:
axes_a = list(range(-axes, 0))
axes_b = list(range(0, axes))
else:
axes_a, axes_b = axes
try:
na = len(axes_a)
axes_a = list(axes_a)
except TypeError:
axes_a = [axes_a]
na = 1
try:
nb = len(axes_b)
axes_b = list(axes_b)
except TypeError:
axes_b = [axes_b]
nb = 1
return axes_a, axes_b
對於標量值 0,1,2,結果為:
In [281]: foo(0)
Out[281]: ([], [])
In [282]: foo(1)
Out[282]: ([-1], [0])
In [283]: foo(2)
Out[283]: ([-2, -1], [0, 1])
axes=1
與在元組中指定相同:
In [284]: foo((-1,0))
Out[284]: ([-1], [0])
對於 2:
In [285]: foo(((-2,-1),(0,1)))
Out[285]: ([-2, -1], [0, 1])
在我的最新示例中, axes=2
與在 2 個數組的所有軸上指定一個dot
相同:
In [287]: np.tensordot(a,b,axes=2)
Out[287]: array(18)
In [288]: np.tensordot(a,b,axes=((0,1),(0,1)))
Out[288]: array(18)
這與在數組的扁平化 1d 視圖上做dot
相同:
In [289]: np.dot(a.ravel(), b.ravel())
Out[289]: 18
我已經演示了這些數組的傳統點積, axes=1
情況。
axes=0
與axes=((),())
,兩個數組沒有求和軸:
In [292]: foo(((),()))
Out[292]: ([], [])
np.tensordot(a,b,((),()))
與np.tensordot(a,b,axes=0)
當輸入數組為 1d 時, foo(2)
轉換中的-2
會給您帶來問題。 axes=1
是一維數組的“收縮”。 換句話說,不要太字面理解文檔中的描述。 他們只是試圖描述代碼的動作; 它們不是正式的規范。
我認為einsum
的軸規格更清晰、更強大。 這是 0,1,2 的等價物
In [295]: np.einsum('ij,kl',a,b)
Out[295]:
array([[[[ 2, 3],
[ 2, 1]],
[[ 4, 6],
[ 4, 2]]],
[[[ 6, 9],
[ 6, 3]],
[[ 8, 12],
[ 8, 4]]]])
In [296]: np.einsum('ij,jk',a,b)
Out[296]:
array([[ 6, 5],
[14, 13]])
In [297]: np.einsum('ij,ij',a,b)
Out[297]: 18
axis=0 的情況,相當於:
np.dot(a[:,:,None],b[:,None,:])
它添加了一個新的最后一個軸和新的第二個到最后一個軸,並對它們進行傳統的點積求和。 但是我們通常用廣播做這種“外部”乘法:
a[:,:,None,None]*b[None,None,:,:]
雖然對軸使用 0,1,2 很有趣,但它實際上並沒有增加新的計算能力。 軸的元組形式更強大和有用。
1 - 將axes
轉換為axes_a
和axes_b
如上述foo
函數中摘錄的
2 - 將a
和b
組成數組,並獲得形狀和 ndim
3 - 檢查將相加的軸上的匹配大小(收縮)
4 - 構造一個newshape_a
和newaxes_a
; b
相同(復雜步驟)
5 - at = a.transpose(newaxes_a).reshape(newshape_a)
; b
6 - res = dot(at, bt)
7 - 將res
重塑為所需的返回形狀
5和6是計算核心。 4 是概念上最復雜的步驟。 對於所有axes
值,計算都是相同的, dot
積,但設置不同。
雖然文檔只提到了標量軸的 0,1,2,但代碼不限於這些值
In [331]: foo(3)
Out[331]: ([-3, -2, -1], [0, 1, 2])
如果輸入為 3,則軸 = 3 應該可以工作:
In [330]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=3)
Out[330]: array(8.)
或更一般地說:
In [325]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=0).shape
Out[325]: (2, 2, 2, 2, 2, 2)
In [326]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=1).shape
Out[326]: (2, 2, 2, 2)
In [327]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=2).shape
Out[327]: (2, 2)
In [328]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=3).shape
Out[328]: ()
如果輸入為 0d,則 axes=0 有效(axes = 1 無效):
In [335]: np.tensordot(2,3, axes=0)
Out[335]: array(6)
你能解釋一下嗎?
In [363]: np.tensordot(np.ones((4,2,3)),np.ones((2,3,4)),axes=2).shape
Out[363]: (4, 4)
我已經嘗試過 3d 數組的其他標量軸值。 雖然可以提出有效的形狀對,但更明確的元組軸值更容易使用。 0,1,2
選項是僅適用於特殊情況的捷徑。 元組方法更容易使用 - 盡管我仍然更喜歡einsum
表示法。
np.tensordot([1, 1], [2, 2], axes=0)
在這種情況下, a和b都具有單個軸並具有形狀(2,)
。
所述axes=0
參數可以被轉換為((最后0軸的一個),(第一個0軸線B的)),或者在這種情況下((), ())
這些是將要收縮的軸。
所有其他軸都不會收縮。 由於a和b 中的每一個都有第 0 個軸而沒有其他軸,因此這些軸是((0,), (0,))
。
然后tensordot操作如下(大致):
[
[x*y for y in b] # all the non-contraction axes in b
for x in a # all the non-contraction axes in a
]
請注意,由於a和b之間共有 2 個可用軸,並且由於我們收縮了其中的 0 個,因此結果有 2 個軸。 形狀是(2,2)
因為它們是a和b 中各個非收縮軸的形狀(按順序)。
np.tensordot([1, 1], [2, 2], axes=1)
所述axes=1
參數可以被轉換為((最后1軸的),(在第一1個軸線B的)),或者在這種情況下((0,), (0,))
這些是將要收縮的軸
所有其他軸都不會收縮。 由於我們已經在收縮每個軸,剩下的軸是((), ())
。
然后tensordot操作如下:
sum( # summing over contraction axis
[x*y for x,y in zip(a, b)] # contracted axes must line up
)
請注意,由於我們正在收縮所有軸,因此結果是一個標量(或 0 形張量)。 在 numpy 中,您只會得到一個形狀為()
的張量,表示 0 軸而不是實際的標量。
np.tensordot([1, 1], [2, 2], axes=2)
這不起作用的原因是因為a和b都沒有兩個單獨的軸可以收縮。
np.tensordot([[1,1],[1,1]], [[2,2],[2,2]], axes=1)
我跳過了你的幾個例子,因為它們並不復雜到比我認為的前幾個更清晰。
在這種情況下, a和b都有兩個可用的軸(讓這個問題更有趣一些),並且它們都有形狀(2,2)
。
所述axes=1
論點仍然代表的最后1軸和b的第一1軸,留給我們((1,), (0,))
這些是將要收縮的軸。
其余的軸不收縮,並有助於最終解決方案的形狀。 它們是((0,), (1,))
。
然后我們可以構建 tensordot 操作。 為了便於論證,假設a和b是 numpy 數組,以便我們可以使用數組屬性並使問題更清晰(例如b=np.array([[2,2],[2,2]])
)。
[
[
sum( # summing the contracted indices
[x*y for x,y in zip(v,w)] # axis 1 of a and axis 0 of b must line up for the summation
)
for w in b.T # iterating over axis 1 of b (i.e. the columns)
]
for v in a # iterating over axis 0 of a (i.e. the rows)
]
結果具有形狀(a.shape[0], b.shape[1])
因為這些是非收縮軸。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.