I have two 3-D arrays of the same size, a and b:
import numpy as np

np.random.seed([3, 14159])
a = np.random.randint(10, size=(4, 3, 2))
b = np.random.randint(10, size=(4, 3, 2))
print(a)
[[[4 8]
  [1 1]
  [9 2]]

 [[8 1]
  [4 2]
  [8 2]]

 [[8 4]
  [9 4]
  [3 4]]

 [[1 5]
  [1 2]
  [6 2]]]
print(b)
[[[7 7]
  [1 1]
  [7 8]]

 [[7 4]
  [8 0]
  [0 9]]

 [[3 8]
  [7 7]
  [2 6]]

 [[3 1]
  [9 3]
  [0 5]]]
I want to take the first array from a:
a[0]
[[4 8]
 [1 1]
 [9 2]]
And the first one from b:
b[0]
[[7 7]
 [1 1]
 [7 8]]
And return this:
a[0].T.dot(b[0])
[[ 92 101]
 [ 71  73]]
But I want to do this over the entire first dimension. I thought I could use np.einsum:
np.einsum('abc,ade->ace', a, b)
[[[210 224]
  [165 176]]

 [[300 260]
  [ 75  65]]

 [[240 420]
  [144 252]]

 [[ 96  72]
  [108  81]]]
This has the correct shape, but not the correct values.
I expect to get this:
np.array([x.T.dot(y).tolist() for x, y in zip(a, b)])
[[[ 92 101]
  [ 71  73]]

 [[ 88 104]
  [ 23  22]]

 [[ 93 145]
  [ 48  84]]

 [[ 12  34]
  [ 33  21]]]
The matrix multiplication amounts to a sum of products where the sum is taken over the middle axis, so the index b should be the same for both arrays (i.e. change ade to abe):
In [40]: np.einsum('abc,abe->ace', a, b)
Out[40]:
array([[[ 92, 101],
        [ 71,  73]],

       [[ 88, 104],
        [ 23,  22]],

       [[ 93, 145],
        [ 48,  84]],

       [[ 12,  34],
        [ 33,  21]]])
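To make the "sum over the shared middle index" concrete, here is a minimal sketch that spells out 'abc,abe->ace' as plain Python loops and compares it with einsum and with the list-comprehension reference from the question. The helper name batched_at_b_loops is hypothetical, chosen only for illustration; it is far slower than einsum and only meant to show which indices are kept and which one is summed.

```python
import numpy as np

np.random.seed([3, 14159])
a = np.random.randint(10, size=(4, 3, 2))
b = np.random.randint(10, size=(4, 3, 2))

def batched_at_b_loops(a, b):
    """Hypothetical loop version of np.einsum('abc,abe->ace', a, b).

    out[i, c, e] = sum over j of a[i, j, c] * b[i, j, e],
    where j is the middle axis shared by both inputs.
    """
    M, N, R1 = a.shape
    _, _, R2 = b.shape
    out = np.zeros((M, R1, R2), dtype=np.result_type(a, b))
    for i in range(M):              # batch axis: kept in the output
        for c in range(R1):         # last axis of a: kept
            for e in range(R2):     # last axis of b: kept
                for j in range(N):  # shared middle axis: summed away
                    out[i, c, e] += a[i, j, c] * b[i, j, e]
    return out

loops = batched_at_b_loops(a, b)
fast = np.einsum('abc,abe->ace', a, b)
ref = np.array([x.T.dot(y) for x, y in zip(a, b)])
print(np.array_equal(loops, fast), np.array_equal(fast, ref))  # True True
```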
When an index subscript appears in the inputs but is missing from the output array, it is summed over. Because b and d are different letters, each appears in only one input, so they are summed over independently. That is,
np.einsum('abc,ade->ace', a, b)
is equivalent to
In [44]: np.einsum('abc,ade->acebd', a, b).sum(axis=-1).sum(axis=-1)
Out[44]:
array([[[210, 224],
        [165, 176]],

       [[300, 260],
        [ 75,  65]],

       [[240, 420],
        [144, 252]],

       [[ 96,  72],
        [108,  81]]])
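Since b and d are summed independently, the "wrong" result factors into two separate column sums followed by a batched outer product. The sketch below, using the same seeded a and b as the question, checks that reading; the variable names a_sum, b_sum and outer are just illustrative.

```python
import numpy as np

np.random.seed([3, 14159])
a = np.random.randint(10, size=(4, 3, 2))
b = np.random.randint(10, size=(4, 3, 2))

# 'b' occurs only in the first operand and 'd' only in the second, and
# neither is in the output, so each is summed away on its own operand ...
a_sum = a.sum(axis=1)   # shape (4, 2), indices a, c
b_sum = b.sum(axis=1)   # shape (4, 2), indices a, e

# ... and the remaining indices combine as a batched outer product.
outer = a_sum[:, :, None] * b_sum[:, None, :]   # shape (4, 2, 2): a, c, e

wrong = np.einsum('abc,ade->ace', a, b)
print(np.array_equal(outer, wrong))  # True
```

For the first block this gives column sums (14, 11) for a[0] and (15, 16) for b[0], and 14 * 15 = 210, 14 * 16 = 224, 11 * 15 = 165, 11 * 16 = 176, matching the first block of the output.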
Here's one with np.matmul. We need to push the second axis of a to the end, so that it gets sum-reduced against the second axis of b, while keeping their first axes aligned:
np.matmul(a.swapaxes(1, 2), b)
Schematically:
At start:
a : M x N x R1
b : M x N x R2
With swapped axes for a:
a : M x R1 x [N]
b : M x [N] x R2
The bracketed axes get sum-reduced, leaving us with:
out : M x R1 x R2
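The schematic is easiest to check with sizes that are all different, so that every axis is identifiable. The sizes M, N, R1, R2 below are hypothetical, and the sketch uses np.random.default_rng (NumPy 1.17+) instead of the question's seeded arrays. It confirms the output shape and shows that without the swap the core dimensions do not line up, so np.matmul raises a ValueError.

```python
import numpy as np

# Hypothetical sizes, chosen distinct so each axis can be told apart.
M, N, R1, R2 = 5, 7, 3, 4
rng = np.random.default_rng(0)
a = rng.integers(10, size=(M, N, R1))
b = rng.integers(10, size=(M, N, R2))

# (M, R1, [N]) @ (M, [N], R2): the bracketed N axes are summed away.
out = np.matmul(a.swapaxes(1, 2), b)
print(out.shape)  # (5, 3, 4), i.e. M x R1 x R2

# Without the swap the stacked matrices are (N, R1) @ (N, R2); since
# R1 != N, the inner dimensions disagree and matmul rejects the product.
mismatch_rejected = False
try:
    np.matmul(a, b)
except ValueError as err:
    mismatch_rejected = True
    print("ValueError:", err)
```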
On Python 3.5+, matmul is also available through the @ operator:
a.swapaxes(1, 2) @ b
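As a quick sanity check, the sketch below runs the question's seeded arrays through np.matmul, the @ operator and the corrected einsum and confirms that all three agree. It also checks that swapaxes returns a view of a (true for ndarrays in NumPy 1.10+), so the transpose itself copies no data.

```python
import numpy as np

np.random.seed([3, 14159])
a = np.random.randint(10, size=(4, 3, 2))
b = np.random.randint(10, size=(4, 3, 2))

# swapaxes(1, 2) turns a from (M, N, R1) into (M, R1, N) as a view.
a_t = a.swapaxes(1, 2)
print(np.shares_memory(a_t, a))  # True

via_matmul = np.matmul(a_t, b)   # stacked (R1, N) @ (N, R2) products
via_operator = a_t @ b           # the same call through the @ operator
via_einsum = np.einsum('abc,abe->ace', a, b)
print(np.array_equal(via_matmul, via_operator),
      np.array_equal(via_matmul, via_einsum))  # True True
```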