簡體   English   中英

Numpy 3d 矩陣和二維矩陣之間的點積

[英]Numpy dot product between a 3d matrix and 2d matrix

我有一個形狀為(2, 10, 3)的 3d 數組和一個形狀為(2, 3) ) 的二維數組,如下所示:

print(t) #2d array

Output:

[[1.003   2.32    3.11   ]
 [1.214   5.32    2.13241]]

print(normal) #3d array

Output:

[[[0.69908573 0.0826756  0.84485978]
  [0.51058213 0.4052637  0.5068118 ]
  [0.45974276 0.25819549 0.10780089]
  [0.27484999 0.33367648 0.128262  ]
  [0.35963389 0.77600065 0.89393939]
  [0.46937506 0.59291623 0.06620307]
  [0.87603987 0.44414505 0.83394174]
  [0.83186093 0.62491876 0.38160734]
  [0.96819897 0.80183442 0.75102768]
  [0.54182908 0.19403844 0.07925769]]

 [[2.82248573 3.2341756  0.96825978]
  [2.63398213 3.5567637  0.6302118 ]
  [2.58314276 3.40969549 0.23120089]
  [2.39824999 3.48517648 0.251662  ]
  [2.48303389 3.92750065 1.01733939]
  [2.59277506 3.74441623 0.18960307]
  [2.99943987 3.59564505 0.95734174]
  [2.95526093 3.77641876 0.50500734]
  [3.09159897 3.95333442 0.87442768]
  [2.66522908 3.34553844 0.20265769]]]

如何獲取 2d 數組t中的每一行以獲取 3d 數組normal中的相應點積,以便數組最終得到一個形狀(2, 10) ,其中每個包含 2d 中第 n 行之間的所有 10 個點積3d 數組中的數組和第 n 個矩陣?

[0.62096458 0.62618459 0.37528887 0.5728386  1.19634398 0.79620507
 1.997884   0.75229492 1.2236496  0.4210626 ]
[2.96347746 3.30738892 3.50596579 4.93082295 5.33811805 4.44872493
 7.33480393 4.19173472 4.7406248  7.83229689]

您可以使用numpy.einsum獲得此結果:

import numpy as np

normal = np.array([
    [1.003,2.32,3.11],
    [1.214,5.32,2.13241]
])


t = np.array([
    [ 
        [0.69908573, 0.0826756,  0.84485978],
        [0.51058213, 0.4052637,  0.5068118 ],
        [0.45974276, 0.25819549, 0.10780089],
        [0.27484999, 0.33367648, 0.128262  ],
        [0.35963389, 0.77600065, 0.89393939],
        [0.46937506, 0.59291623, 0.06620307],
        [0.87603987, 0.44414505, 0.83394174],
        [0.83186093, 0.62491876, 0.38160734],
        [0.96819897, 0.80183442, 0.75102768],
        [0.54182908, 0.19403844, 0.07925769]
    ],

    [
        [2.82248573, 3.2341756,  0.96825978],
        [2.63398213, 3.5567637,  0.6302118 ],
        [2.58314276, 3.40969549, 0.23120089],
        [2.39824999, 3.48517648, 0.251662  ],
        [2.48303389, 3.92750065, 1.01733939],
        [2.59277506, 3.74441623, 0.18960307],
        [2.99943987, 3.59564505, 0.95734174],
        [2.95526093, 3.77641876, 0.50500734],
        [3.09159897, 3.95333442, 0.87442768],
        [2.66522908, 3.34553844, 0.20265769]
    ]
])

np.einsum('ijk,ik->ij', t, normal)

這導致

array([[ 3.52050429,  3.02851036,  1.39539629,  1.44869879,  4.9411858 ,
         2.05224039,  4.50264332,  3.47096686,  5.16705551,  1.24011516],
       [22.69703871, 23.46350713, 21.76853041, 21.98926093, 26.07809129,
        23.47223475, 24.81159677, 24.75511727, 26.64957859, 21.46600189]])

這與按順序進行兩次乘法相同:

t[0] @ normal[0] 
t[1] @ normal[1] 

給出兩個:

array([3.52050429, 3.02851036, 1.39539629, 1.44869879, 4.9411858 ,
       2.05224039, 4.50264332, 3.47096686, 5.16705551, 1.24011516])
array([22.69703871, 23.46350713, 21.76853041, 21.98926093, 26.07809129,
       23.47223475, 24.81159677, 24.75511727, 26.64957859, 21.46600189])

我認為np.tensordot您的要求,盡管結果與您的不同。

import numpy as np

t = np.array(
[[1.003, 2.32, 3.11 ],
 [1.214, 5.32, 2.13241]]
)

normal = np.array(
[[[0.69908573, 0.0826756, 0.84485978],
  [0.51058213, 0.4052637, 0.5068118, ],
  [0.45974276, 0.25819549, 0.10780089],
  [0.27484999, 0.33367648, 0.128262, ],
  [0.35963389, 0.77600065, 0.89393939],
  [0.46937506, 0.59291623, 0.06620307],
  [0.87603987, 0.44414505, 0.83394174],
  [0.83186093, 0.62491876, 0.38160734],
  [0.96819897, 0.80183442, 0.75102768],
  [0.54182908, 0.19403844, 0.07925769]],

 [[2.82248573, 3.2341756, 0.96825978],
  [2.63398213, 3.5567637, 0.6302118, ],
  [2.58314276, 3.40969549, 0.23120089],
  [2.39824999, 3.48517648, 0.251662, ],
  [2.48303389, 3.92750065, 1.01733939],
  [2.59277506, 3.74441623, 0.18960307],
  [2.99943987, 3.59564505, 0.95734174],
  [2.95526093, 3.77641876, 0.50500734],
  [3.09159897, 3.95333442, 0.87442768],
  [2.66522908, 3.34553844, 0.20265769]]]
)

print(normal)

t = np.tensordot( normal, t, axes=([2],[1]))
print(t)

Output:

[[[0.69908573 0.0826756  0.84485978]
  [0.51058213 0.4052637  0.5068118 ]
  [0.45974276 0.25819549 0.10780089]
  [0.27484999 0.33367648 0.128262  ]
  [0.35963389 0.77600065 0.89393939]
  [0.46937506 0.59291623 0.06620307]
  [0.87603987 0.44414505 0.83394174]
  [0.83186093 0.62491876 0.38160734]
  [0.96819897 0.80183442 0.75102768]
  [0.54182908 0.19403844 0.07925769]]

 [[2.82248573 3.2341756  0.96825978]
  [2.63398213 3.5567637  0.6302118 ]
  [2.58314276 3.40969549 0.23120089]
  [2.39824999 3.48517648 0.251662  ]
  [2.48303389 3.92750065 1.01733939]
  [2.59277506 3.74441623 0.18960307]
  [2.99943987 3.59564505 0.95734174]
  [2.95526093 3.77641876 0.50500734]
  [3.09159897 3.95333442 0.87442768]
  [2.66522908 3.34553844 0.20265769]]]
[[[ 3.52050429  3.09011171]
  [ 3.02851036  3.85658014]
  [ 1.39539629  2.16160341]
  [ 1.44869879  2.38233393]
  [ 4.9411858   6.4711643 ]
  [ 2.05224039  3.86530775]
  [ 4.50264332  5.20466977]
  [ 3.47096686  5.14819028]
  [ 5.16705551  7.0426516 ]
  [ 1.24011516  1.85907489]]

 [[13.34552849 22.69703871]
  [12.85353456 23.46350713]
  [11.22042049 21.76853041]
  [11.27372299 21.98926093]
  [14.76621    26.07809129]
  [11.87726459 23.47223475]
  [14.32766752 24.81159677]
  [13.29599106 24.75511727]
  [14.99207971 26.64957859]
  [11.06513936 21.46600189]]]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM