
Matrix-multiply every pair of 2-D arrays along the first dimension with einsum

I have two 3-D arrays of the same shape, a and b:

import numpy as np

np.random.seed([3,14159])
a = np.random.randint(10, size=(4, 3, 2))
b = np.random.randint(10, size=(4, 3, 2))

print(a)

[[[4 8]
  [1 1]
  [9 2]]

 [[8 1]
  [4 2]
  [8 2]]

 [[8 4]
  [9 4]
  [3 4]]

 [[1 5]
  [1 2]
  [6 2]]]

print(b)

[[[7 7]
  [1 1]
  [7 8]]

 [[7 4]
  [8 0]
  [0 9]]

 [[3 8]
  [7 7]
  [2 6]]

 [[3 1]
  [9 3]
  [0 5]]]

I want to take the first array from a

a[0]

[[4 8]
 [1 1]
 [9 2]]

And the first one from b

b[0]

[[7 7]
 [1 1]
 [7 8]]

And return this

a[0].T.dot(b[0])

[[ 92 101]
 [ 71  73]]

But I want to do this for every pair along the first dimension. I thought I could use np.einsum:

np.einsum('abc,ade->ace', a, b)

[[[210 224]
  [165 176]]

 [[300 260]
  [ 75  65]]

 [[240 420]
  [144 252]]

 [[ 96  72]
  [108  81]]]

This has the correct shape, but not the correct values.

I expect to get this:

np.array([x.T.dot(y).tolist() for x, y in zip(a, b)])

[[[ 92 101]
  [ 71  73]]

 [[ 88 104]
  [ 23  22]]

 [[ 93 145]
  [ 48  84]]

 [[ 12  34]
  [ 33  21]]]

Matrix multiplication is a sum of products, with the sum taken over the shared axis (here the middle one). The index b must therefore appear in the subscripts of both arrays, i.e. change ade to abe:

In [40]: np.einsum('abc,abe->ace', a, b)
Out[40]: 
array([[[ 92, 101],
        [ 71,  73]],

       [[ 88, 104],
        [ 23,  22]],

       [[ 93, 145],
        [ 48,  84]],

       [[ 12,  34],
        [ 33,  21]]])
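
As a quick sanity check, the sketch below types in the a and b arrays printed in the question as literals (so it does not depend on the random seed) and compares the corrected einsum call with the explicit loop from the question. The names batched and looped are introduced here for illustration only.

```python
import numpy as np

# The arrays printed in the question, entered as literals.
a = np.array([[[4, 8], [1, 1], [9, 2]],
              [[8, 1], [4, 2], [8, 2]],
              [[8, 4], [9, 4], [3, 4]],
              [[1, 5], [1, 2], [6, 2]]])
b = np.array([[[7, 7], [1, 1], [7, 8]],
              [[7, 4], [8, 0], [0, 9]],
              [[3, 8], [7, 7], [2, 6]],
              [[3, 1], [9, 3], [0, 5]]])

# Shared label 'b' on the middle axis: summed as a product, pair by pair.
batched = np.einsum('abc,abe->ace', a, b)

# Reference result: transpose-and-dot each pair in a Python loop.
looped = np.stack([x.T.dot(y) for x, y in zip(a, b)])

print(np.array_equal(batched, looped))  # True
print(batched[0])
# [[ 92 101]
#  [ 71  73]]
```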

When index subscripts appear in the input arrays but not in the output array, einsum sums over them. Because b and d are different labels here, each one is summed over independently. That is,

np.einsum('abc,ade->ace', a, b)

is equivalent to

In [44]: np.einsum('abc,ade->acebd', a, b).sum(axis=-1).sum(axis=-1)
Out[44]: 
array([[[210, 224],
        [165, 176]],

       [[300, 260],
        [ 75,  65]],

       [[240, 420],
        [144, 252]],

       [[ 96,  72],
        [108,  81]]])
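
Another way to see what the original subscripts compute: since b and d are summed independently, the result factorizes into an outer product of the column sums of each pair. The following is a small sketch of that reading, using only the first pair from the question (a[0] and b[0]) with a batch axis of length 1; the names wrong and factored are illustrative.

```python
import numpy as np

# First pair from the question, kept 3-D with a batch axis of length 1.
a = np.array([[[4, 8], [1, 1], [9, 2]]])   # shape (1, 3, 2)
b = np.array([[[7, 7], [1, 1], [7, 8]]])   # shape (1, 3, 2)

# Original subscripts: 'b' and 'd' are different labels, so each is
# summed on its own rather than as a product over a shared axis.
wrong = np.einsum('abc,ade->ace', a, b)

# Equivalent: sum each array over its middle axis first, then take the
# outer product of the two resulting vectors for every batch entry.
factored = a.sum(axis=1)[:, :, None] * b.sum(axis=1)[:, None, :]

print(np.array_equal(wrong, factored))  # True
print(wrong[0])
# [[210 224]
#  [165 176]]
```

The column sums of a[0] are [14, 11] and those of b[0] are [15, 16], which is why the first block of the unintended output is [[210, 224], [165, 176]].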

Here's an alternative with np.matmul. We need to move the second axis of a to the end so that it gets sum-reduced against the second axis of b, while the first axes of both arrays stay aligned:

np.matmul(a.swapaxes(1,2),b) 

Schematically:

At the start:

a   :  M x N x R1
b   :  M x N x R2

With the axes of a swapped:

a   :  M x R1 x [N]
b   :  M x [N] x R2

The bracketed axes get sum-reduced, leaving us with:

out :  M x R1 x R2

On Python 3.5 and later, np.matmul is also available through the @ operator:

a.swapaxes(1,2) @ b
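
The shape bookkeeping above can be checked with a short sketch. It assumes arbitrary sizes M=4, N=3, R1=2, R2=5 (chosen here so that every axis length is different, not taken from the question) and random integer inputs, and compares the matmul route against the einsum route.

```python
import numpy as np

# Assumed sizes for illustration; distinct values make the axes easy to track.
M, N, R1, R2 = 4, 3, 2, 5
rng = np.random.default_rng(0)
a = rng.integers(10, size=(M, N, R1))
b = rng.integers(10, size=(M, N, R2))

# Swap the last two axes of a: (M, N, R1) -> (M, R1, N), then batch-multiply
# with b of shape (M, N, R2), reducing over N.
via_matmul = a.swapaxes(1, 2) @ b

# Same contraction spelled out with einsum: shared label 'n' is summed.
via_einsum = np.einsum('mnr,mns->mrs', a, b)

print(via_matmul.shape)                          # (4, 2, 5)
print(np.array_equal(via_matmul, via_einsum))    # True
```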


 