簡體   English   中英

numpy.tensordot 函數是如何逐步工作的?

[英]How does numpy.tensordot function works step-by-step?

我是 numpy 的新手,所以我在可視化numpy.tensordot()函數的工作時numpy.tensordot()一些問題。 根據tensordot的文檔,軸在參數中傳遞,其中軸=0 或 1 表示正常矩陣乘法,而軸=2 表示收縮。

有人可以解釋一下乘法將如何處理給定的例子嗎?

示例 1: a=[1,1] b=[2,2] for axes=0,1為什么它會在 axes=2 時引發錯誤?
示例 2: a=[[1,1],[1,1]] b=[[2,2],[2,2]] for axes=0,1,2

編輯:此答案的最初重點是axes是元組的情況,為每個參數指定一個或多個軸。 這種用法允許我們對傳統的dot進行變體,特別是對於大於 2d 的數組(我在鏈接問題中的回答也是, https://stackoverflow.com/a/41870980/901925 )。 作為標量的軸是一種特殊情況,它被翻譯成元組版本。 所以它的核心仍然是一個dot積。

軸作為元組

In [235]: a=[1,1]; b=[2,2]

ab是列表; tensordot將它們變成數組。

In [236]: np.tensordot(a,b,(0,0))
Out[236]: array(4)

由於它們都是一維數組,我們將軸值指定為 0。

如果我們嘗試指定 1:

In [237]: np.tensordot(a,b,(0,1))
---------------------------------------------------------------------------
   1282     else:
   1283         for k in range(na):
-> 1284             if as_[axes_a[k]] != bs[axes_b[k]]:
   1285                 equal = False
   1286                 break

IndexError: tuple index out of range

它正在檢查a的軸 0 的大小是否與b的軸 1 的大小匹配。 但由於b是 1d,它無法檢查。

In [239]: np.array(a).shape[0]
Out[239]: 2
In [240]: np.array(b).shape[1]
IndexError: tuple index out of range

你的第二個例子是二維數組:

In [242]: a=np.array([[1,1],[1,1]]); b=np.array([[2,2],[2,2]])

指定的最后一個軸a和第一b (第二到最后),產生傳統的矩陣(點)產物:

In [243]: np.tensordot(a,b,(1,0))
Out[243]: 
array([[4, 4],
       [4, 4]])
In [244]: a.dot(b)
Out[244]: 
array([[4, 4],
       [4, 4]])

更好的診斷值:

In [250]: a=np.array([[1,2],[3,4]]); b=np.array([[2,3],[2,1]])
In [251]: np.tensordot(a,b,(1,0))
Out[251]: 
array([[ 6,  5],
       [14, 13]])
In [252]: np.dot(a,b)
Out[252]: 
array([[ 6,  5],
       [14, 13]])

In [253]: np.tensordot(a,b,(0,1))
Out[253]: 
array([[11,  5],
       [16,  8]])
In [254]: np.dot(b,a)      # same numbers, different layout
Out[254]: 
array([[11, 16],
       [ 5,  8]])
In [255]: np.dot(b,a).T
Out[255]: 
array([[11,  5],
       [16,  8]])

另一個配對:

In [256]: np.tensordot(a,b,(0,0))
In [257]: np.dot(a.T,b)

(0,1,2) 軸是完全錯誤的。 軸參數應該是 2 個數字或 2 個元組,對應於 2 個參數。

tensordot的基本處理是對輸入進行轉置和重塑,以便它可以將結果傳遞給np.dot用於常規(a 的最后一個,b 的最后一個)矩陣乘積。

軸作為標量

如果我對tensordot代碼的閱讀是正確的,則axes參數將轉換為兩個列表:

def foo(axes):
    try:
        iter(axes)
    except Exception:
        axes_a = list(range(-axes, 0))
        axes_b = list(range(0, axes))
    else:
        axes_a, axes_b = axes
    try:
        na = len(axes_a)
        axes_a = list(axes_a)
    except TypeError:
        axes_a = [axes_a]
        na = 1
    try:
        nb = len(axes_b)
        axes_b = list(axes_b)
    except TypeError:
        axes_b = [axes_b]
        nb = 1

    return axes_a, axes_b

對於標量值 0,1,2,結果為:

In [281]: foo(0)
Out[281]: ([], [])
In [282]: foo(1)
Out[282]: ([-1], [0])
In [283]: foo(2)
Out[283]: ([-2, -1], [0, 1])

axes=1與在元組中指定相同:

In [284]: foo((-1,0))
Out[284]: ([-1], [0])

對於 2:

In [285]: foo(((-2,-1),(0,1)))
Out[285]: ([-2, -1], [0, 1])

在我的最新示例中, axes=2與在 2 個數組的所有軸上指定一個dot相同:

In [287]: np.tensordot(a,b,axes=2)
Out[287]: array(18)
In [288]: np.tensordot(a,b,axes=((0,1),(0,1)))
Out[288]: array(18)

這與在數組的扁平化 1d 視圖上做dot相同:

In [289]: np.dot(a.ravel(), b.ravel())
Out[289]: 18

我已經演示了這些數組的傳統點積, axes=1情況。

axes=0axes=((),()) ,兩個數組沒有求和軸:

In [292]: foo(((),()))
Out[292]: ([], [])

np.tensordot(a,b,((),()))np.tensordot(a,b,axes=0)

當輸入數組為 1d 時, foo(2)轉換中的-2會給您帶來問題。 axes=1是一維數組的“收縮”。 換句話說,不要太字面理解文檔中的描述。 他們只是試圖描述代碼的動作; 它們不是正式的規范。

等價物

我認為einsum的軸規格更清晰、更強大。 這是 0,1,2 的等價物

In [295]: np.einsum('ij,kl',a,b)
Out[295]: 
array([[[[ 2,  3],
         [ 2,  1]],

        [[ 4,  6],
         [ 4,  2]]],


       [[[ 6,  9],
         [ 6,  3]],

        [[ 8, 12],
         [ 8,  4]]]])
In [296]: np.einsum('ij,jk',a,b)
Out[296]: 
array([[ 6,  5],
       [14, 13]])
In [297]: np.einsum('ij,ij',a,b)
Out[297]: 18

axis=0 的情況,相當於:

np.dot(a[:,:,None],b[:,None,:])

它添加了一個新的最后一個軸和新的第二個到最后一個軸,並對它們進行傳統的點積求和。 但是我們通常用廣播做這種“外部”乘法:

a[:,:,None,None]*b[None,None,:,:]

雖然對軸使用 0,1,2 很有趣,但它實際上並沒有增加新的計算能力。 軸的元組形式更強大和有用。

代碼總結(大步驟)

1 - 將axes轉換為axes_aaxes_b如上述foo函數中摘錄的

2 - 將ab組成數組,並獲得形狀和 ndim

3 - 檢查將相加的軸上的匹配大小(收縮)

4 - 構造一個newshape_anewaxes_a b相同(復雜步驟)

5 - at = a.transpose(newaxes_a).reshape(newshape_a) ; b

6 - res = dot(at, bt)

7 - 將res重塑為所需的返回形狀

5和6是計算核心。 4 是概念上最復雜的步驟。 對於所有axes值,計算都是相同的, dot積,但設置不同。

超過 0,1,2

雖然文檔只提到了標量軸的 0,1,2,但代碼不限於這些值

In [331]: foo(3)
Out[331]: ([-3, -2, -1], [0, 1, 2])

如果輸入為 3,則軸 = 3 應該可以工作:

In [330]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=3)
Out[330]: array(8.)

或更一般地說:

In [325]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=0).shape
Out[325]: (2, 2, 2, 2, 2, 2)
In [326]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=1).shape
Out[326]: (2, 2, 2, 2)
In [327]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=2).shape
Out[327]: (2, 2)
In [328]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=3).shape
Out[328]: ()

如果輸入為 0d,則 axes=0 有效(axes = 1 無效):

In [335]: np.tensordot(2,3, axes=0)
Out[335]: array(6)

你能解釋一下嗎?

In [363]: np.tensordot(np.ones((4,2,3)),np.ones((2,3,4)),axes=2).shape
Out[363]: (4, 4)

我已經嘗試過 3d 數組的其他標量軸值。 雖然可以提出有效的形狀對,但更明確的元組軸值更容易使用。 0,1,2選項是僅適用於特殊情況的捷徑。 元組方法更容易使用 - 盡管我仍然更喜歡einsum表示法。

示例 1-0: np.tensordot([1, 1], [2, 2], axes=0)

在這種情況下, ab都具有單個軸並具有形狀(2,)

所述axes=0參數可以被轉換為((最后0軸的一個),(第一個0軸線B的)),或者在這種情況下((), ()) 這些是將要收縮的軸。

所有其他軸都不會收縮。 由於ab 中的一個都有第 0 個軸而沒有其他軸,因此這些軸是((0,), (0,))

然后tensordot操作如下(大致):

[
    [x*y for y in b]  # all the non-contraction axes in b
    for x in a        # all the non-contraction axes in a
]

請注意,由於ab之間共有 2 個可用軸,並且由於我們收縮了其中的 0 個,因此結果有 2 個軸。 形狀是(2,2)因為它們是ab 中各個非收縮軸的形狀(按順序)。

示例 1-1: np.tensordot([1, 1], [2, 2], axes=1)

所述axes=1參數可以被轉換為((最后1),(在第一1個軸線B的)),或者在這種情況下((0,), (0,)) 這些是將要收縮的軸

所有其他軸都不會收縮。 由於我們已經在收縮每個軸,剩下的軸是((), ())

然后tensordot操作如下:

sum(  # summing over contraction axis
    [x*y for x,y in zip(a, b)]  # contracted axes must line up
)

請注意,由於我們正在收縮所有軸,因此結果是一個標量(或 0 形張量)。 在 numpy 中,您只會得到一個形狀為()的張量,表示 0 軸而不是實際的標量。

示例 1-2: np.tensordot([1, 1], [2, 2], axes=2)

這不起作用的原因是因為ab都沒有兩個單獨的軸可以收縮。

示例 2-1: np.tensordot([[1,1],[1,1]], [[2,2],[2,2]], axes=1)

我跳過了你的幾個例子,因為它們並不復雜到比我認為的前幾個更清晰。

在這種情況下, ab都有兩個可用的軸(讓這個問題更有趣一些),並且它們都有形狀(2,2)

所述axes=1論點仍然代表最后1軸和b的第一1軸,留給我們((1,), (0,)) 這些是將要收縮的軸。

其余的軸不收縮,並有助於最終解決方案的形狀。 它們是((0,), (1,))

然后我們可以構建 tensordot 操作。 為了便於論證,假設ab是 numpy 數組,以便我們可以使用數組屬性並使問題更清晰(例如b=np.array([[2,2],[2,2]]) )。

[
    [
        sum(  # summing the contracted indices
            [x*y for x,y in zip(v,w)]  # axis 1 of a and axis 0 of b must line up for the summation
        )
        for w in b.T  # iterating over axis 1 of b (i.e. the columns)
    ]
    for v in a  # iterating over axis 0 of a (i.e. the rows)
]

結果具有形狀(a.shape[0], b.shape[1])因為這些是非收縮軸。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM