![](/img/trans.png)
[英]numpy einsum: Elementwise product between 3D matrix and 2D matrix
[英]Numpy dot product between a 3d matrix and 2d matrix
我有一個形狀為(2, 10, 3)
的 3d 數組和一個形狀為(2, 3)
) 的二維數組,如下所示:
print(t) #2d array
Output:
[[1.003 2.32 3.11 ]
[1.214 5.32 2.13241]]
print(normal) #3d array
Output:
[[[0.69908573 0.0826756 0.84485978]
[0.51058213 0.4052637 0.5068118 ]
[0.45974276 0.25819549 0.10780089]
[0.27484999 0.33367648 0.128262 ]
[0.35963389 0.77600065 0.89393939]
[0.46937506 0.59291623 0.06620307]
[0.87603987 0.44414505 0.83394174]
[0.83186093 0.62491876 0.38160734]
[0.96819897 0.80183442 0.75102768]
[0.54182908 0.19403844 0.07925769]]
[[2.82248573 3.2341756 0.96825978]
[2.63398213 3.5567637 0.6302118 ]
[2.58314276 3.40969549 0.23120089]
[2.39824999 3.48517648 0.251662 ]
[2.48303389 3.92750065 1.01733939]
[2.59277506 3.74441623 0.18960307]
[2.99943987 3.59564505 0.95734174]
[2.95526093 3.77641876 0.50500734]
[3.09159897 3.95333442 0.87442768]
[2.66522908 3.34553844 0.20265769]]]
如何獲取 2d 數組t
中的每一行以獲取 3d 數組normal
中的相應點積,以便數組最終得到一個形狀(2, 10)
,其中每個包含 2d 中第 n 行之間的所有 10 個點積3d 數組中的數組和第 n 個矩陣?
[0.62096458 0.62618459 0.37528887 0.5728386 1.19634398 0.79620507
1.997884 0.75229492 1.2236496 0.4210626 ]
[2.96347746 3.30738892 3.50596579 4.93082295 5.33811805 4.44872493
7.33480393 4.19173472 4.7406248 7.83229689]
您可以使用numpy.einsum
獲得此結果:
import numpy as np
normal = np.array([
[1.003,2.32,3.11],
[1.214,5.32,2.13241]
])
t = np.array([
[
[0.69908573, 0.0826756, 0.84485978],
[0.51058213, 0.4052637, 0.5068118 ],
[0.45974276, 0.25819549, 0.10780089],
[0.27484999, 0.33367648, 0.128262 ],
[0.35963389, 0.77600065, 0.89393939],
[0.46937506, 0.59291623, 0.06620307],
[0.87603987, 0.44414505, 0.83394174],
[0.83186093, 0.62491876, 0.38160734],
[0.96819897, 0.80183442, 0.75102768],
[0.54182908, 0.19403844, 0.07925769]
],
[
[2.82248573, 3.2341756, 0.96825978],
[2.63398213, 3.5567637, 0.6302118 ],
[2.58314276, 3.40969549, 0.23120089],
[2.39824999, 3.48517648, 0.251662 ],
[2.48303389, 3.92750065, 1.01733939],
[2.59277506, 3.74441623, 0.18960307],
[2.99943987, 3.59564505, 0.95734174],
[2.95526093, 3.77641876, 0.50500734],
[3.09159897, 3.95333442, 0.87442768],
[2.66522908, 3.34553844, 0.20265769]
]
])
np.einsum('ijk,ik->ij', t, normal)
這導致
array([[ 3.52050429, 3.02851036, 1.39539629, 1.44869879, 4.9411858 ,
2.05224039, 4.50264332, 3.47096686, 5.16705551, 1.24011516],
[22.69703871, 23.46350713, 21.76853041, 21.98926093, 26.07809129,
23.47223475, 24.81159677, 24.75511727, 26.64957859, 21.46600189]])
這與按順序進行兩次乘法相同:
t[0] @ normal[0]
t[1] @ normal[1]
給出兩個:
array([3.52050429, 3.02851036, 1.39539629, 1.44869879, 4.9411858 ,
2.05224039, 4.50264332, 3.47096686, 5.16705551, 1.24011516])
array([22.69703871, 23.46350713, 21.76853041, 21.98926093, 26.07809129,
23.47223475, 24.81159677, 24.75511727, 26.64957859, 21.46600189])
我認為np.tensordot
您的要求,盡管結果與您的不同。
import numpy as np
t = np.array(
[[1.003, 2.32, 3.11 ],
[1.214, 5.32, 2.13241]]
)
normal = np.array(
[[[0.69908573, 0.0826756, 0.84485978],
[0.51058213, 0.4052637, 0.5068118, ],
[0.45974276, 0.25819549, 0.10780089],
[0.27484999, 0.33367648, 0.128262, ],
[0.35963389, 0.77600065, 0.89393939],
[0.46937506, 0.59291623, 0.06620307],
[0.87603987, 0.44414505, 0.83394174],
[0.83186093, 0.62491876, 0.38160734],
[0.96819897, 0.80183442, 0.75102768],
[0.54182908, 0.19403844, 0.07925769]],
[[2.82248573, 3.2341756, 0.96825978],
[2.63398213, 3.5567637, 0.6302118, ],
[2.58314276, 3.40969549, 0.23120089],
[2.39824999, 3.48517648, 0.251662, ],
[2.48303389, 3.92750065, 1.01733939],
[2.59277506, 3.74441623, 0.18960307],
[2.99943987, 3.59564505, 0.95734174],
[2.95526093, 3.77641876, 0.50500734],
[3.09159897, 3.95333442, 0.87442768],
[2.66522908, 3.34553844, 0.20265769]]]
)
print(normal)
t = np.tensordot( normal, t, axes=([2],[1]))
print(t)
Output:
[[[0.69908573 0.0826756 0.84485978]
[0.51058213 0.4052637 0.5068118 ]
[0.45974276 0.25819549 0.10780089]
[0.27484999 0.33367648 0.128262 ]
[0.35963389 0.77600065 0.89393939]
[0.46937506 0.59291623 0.06620307]
[0.87603987 0.44414505 0.83394174]
[0.83186093 0.62491876 0.38160734]
[0.96819897 0.80183442 0.75102768]
[0.54182908 0.19403844 0.07925769]]
[[2.82248573 3.2341756 0.96825978]
[2.63398213 3.5567637 0.6302118 ]
[2.58314276 3.40969549 0.23120089]
[2.39824999 3.48517648 0.251662 ]
[2.48303389 3.92750065 1.01733939]
[2.59277506 3.74441623 0.18960307]
[2.99943987 3.59564505 0.95734174]
[2.95526093 3.77641876 0.50500734]
[3.09159897 3.95333442 0.87442768]
[2.66522908 3.34553844 0.20265769]]]
[[[ 3.52050429 3.09011171]
[ 3.02851036 3.85658014]
[ 1.39539629 2.16160341]
[ 1.44869879 2.38233393]
[ 4.9411858 6.4711643 ]
[ 2.05224039 3.86530775]
[ 4.50264332 5.20466977]
[ 3.47096686 5.14819028]
[ 5.16705551 7.0426516 ]
[ 1.24011516 1.85907489]]
[[13.34552849 22.69703871]
[12.85353456 23.46350713]
[11.22042049 21.76853041]
[11.27372299 21.98926093]
[14.76621 26.07809129]
[11.87726459 23.47223475]
[14.32766752 24.81159677]
[13.29599106 24.75511727]
[14.99207971 26.64957859]
[11.06513936 21.46600189]]]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.