
Black voodoo of NumPy Einsum

I have some working code that uses the einsum function, but einsum is still black voodoo to me. I was wondering what this code is actually doing, and whether it can be optimized using np.dot.

My data looks like this:

import numpy as np

n, p, q = 40000, 8, 4
a = np.random.rand(n, p, q)
b = np.random.rand(n, p)

And my existing einsum calls look like this:

f1 = np.einsum("ijx,ijy->ixy", a, a)
f2 = np.einsum("ijx,ij->ix", a, b)

But what does it really do? Here is what I understand so far: each dimension (axis) is represented by a label. i is the first axis (n), j is the second axis (p), and x and y are two different labels for the same axis (q). The output labels of f1 are ixy, so the output shape is (40000, 4, 4), i.e. (n, q, q).

But that's as far as I get.

Let's play around with a couple of small arrays:

In [110]: a=np.arange(2*3*4).reshape(2,3,4)

In [111]: b=np.arange(2*3).reshape(2,3)

In [112]: np.einsum('ijx,ij->ix',a,b)
Out[112]: 
array([[ 20,  23,  26,  29],
       [200, 212, 224, 236]])

In [113]: np.diagonal(np.dot(b,a)).T
Out[113]: 
array([[ 20,  23,  26,  29],
       [200, 212, 224, 236]])

np.dot operates on the last dimension of the first array and the second-to-last dimension of the second. So I have to switch the arguments so that the dimensions of size 3 line up. dot(b,a) produces a (2,2,4) array. diagonal selects 2 of those 'rows', and the transpose cleans up. Another einsum expresses that cleanup nicely:

In [122]: np.einsum('iik->ik',np.dot(b,a))

Since np.dot produces a larger intermediate array than the original einsum, it is unlikely to be faster, even if the underlying C code is tighter.
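To spell out what the two einsum strings compute, here is a minimal sketch using plain broadcasting and sum. A label that appears in the inputs but not in the output (here j) is summed over. The names f1_sum and f2_sum are mine, just for illustration.

```python
import numpy as np

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)   # axes labelled (i, j, x)
b = np.arange(2 * 3).reshape(2, 3)          # axes labelled (i, j)

# 'ijx,ij->ix': multiply a[i, j, x] by b[i, j], then sum over the dropped label j.
f2_sum = (a * b[:, :, None]).sum(axis=1)

# 'ijx,ijy->ixy': for each (i, j) form the outer product over x and y,
# then sum over j.
f1_sum = (a[:, :, :, None] * a[:, :, None, :]).sum(axis=1)

f2 = np.einsum('ijx,ij->ix', a, b)
f1 = np.einsum('ijx,ijy->ixy', a, a)
print(np.array_equal(f2, f2_sum), np.array_equal(f1, f1_sum))
```

The broadcasting version builds the full product array before summing, so it is meant as an explanation of the semantics, not as a faster replacement.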

(Curiously, I'm having trouble replicating np.dot(b,a) with einsum; it won't generate that (2,2,...) array.)
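One way that appears to reproduce it is to give the two leading axes different labels, so that einsum does not tie them together. This is a sketch on the small arrays above; the label k and the name full are my own choices.

```python
import numpy as np

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
b = np.arange(2 * 3).reshape(2, 3)

# i indexes the rows of b, k indexes the first axis of a.
# Only j is shared between the operands, and it is summed because
# it does not appear in the output.
full = np.einsum('ij,kjx->ikx', b, a)

print(full.shape)
print(np.array_equal(full, np.dot(b, a)))
```

Reusing i for both leading axes ('ij,ijx->ix') is what restricts the result to the diagonal of this larger array.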

For the a,a case we have to do something similar: roll the axes of one array so that its last dimension lines up with the second-to-last of the other, do the dot, and then clean up with diagonal and transpose:

In [157]: np.einsum('ijx,ijy->ixy',a,a).shape
Out[157]: (2, 4, 4)
In [158]: np.einsum('ijjx->jix',np.dot(np.rollaxis(a,2),a))
In [176]: np.diagonal(np.dot(np.rollaxis(a,2),a),0,2).T
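The intermediate shapes are easy to lose track of here, so this is a small check of the two cleanup routes against the original einsum. The variable names (target, rolled, d, via_einsum, via_diagonal) are only for illustration.

```python
import numpy as np

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)

target = np.einsum('ijx,ijy->ixy', a, a)   # shape (2, 4, 4)

rolled = np.rollaxis(a, 2)                 # shape (4, 2, 3): last axis moved to the front
d = np.dot(rolled, a)                      # shape (4, 2, 2, 4), indexed [x, i, k, y]

# Keep only the entries where the two middle indices agree (i == k).
via_einsum = np.einsum('ijjx->jix', d)

# Same selection with diagonal: offset=0, axis1=2, axis2 left at its default of 1.
# The diagonal axis is appended last, giving [x, y, i]; .T reverses that to
# [i, y, x], which matches [i, x, y] because the result is symmetric in x and y.
via_diagonal = np.diagonal(d, 0, 2).T

print(d.shape)
print(np.array_equal(via_einsum, target), np.array_equal(via_diagonal, target))
```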

tensordot is another way of taking a dot over selected axes.

np.tensordot(a,a,(1,1))
np.diagonal(np.rollaxis(np.tensordot(a,a,(1,1)),1),0,2).T  # with cleanup
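All of the dot and tensordot routes above build a larger array and then throw most of it away. A different technique, not used above, is batched matrix multiplication with the @ operator (np.matmul), which treats the first axis as a batch and never forms the off-diagonal blocks. This is only a sketch: n is reduced from the question's 40000 to keep the check quick, and the names at, f1_mm and f2_mm are mine.

```python
import numpy as np

n, p, q = 1000, 8, 4                   # smaller n than in the question, for a quick check
rng = np.random.default_rng(0)
a = rng.random((n, p, q))
b = rng.random((n, p))

at = a.transpose(0, 2, 1)              # shape (n, q, p)

# (n, q, p) @ (n, p, q) -> (n, q, q), one matrix product per i
f1_mm = at @ a

# (n, q, p) @ (n, p, 1) -> (n, q, 1), then drop the trailing axis
f2_mm = (at @ b[:, :, None])[:, :, 0]

print(np.allclose(f1_mm, np.einsum('ijx,ijy->ixy', a, a)))
print(np.allclose(f2_mm, np.einsum('ijx,ij->ix', a, b)))
```

Whether this beats einsum depends on the NumPy version and the array sizes, so it is worth timing both on the real data before switching.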

