How does the numpy.tensordot command work, and what does summing over an axis mean in this command?

I am trying to understand how the numpy.tensordot command works. I have gone through several questions on this forum about this command. As I understand it, axes=(1,0) indicates that axis 1 of a and axis 0 of b will be summed over. So I summed the terms along axis 1 of a and along axis 0 of b and calculated the answer manually, but the result is different. Maybe my understanding of summing along a particular axis is wrong. Can someone please explain how we get the final result in the following code?

import numpy

a = numpy.array([[1,2],[3,4]])
b = numpy.array([[0,5],[-1,20]])

c = numpy.tensordot(a,b,axes=(1,0))

print(c)

Result:
[[-2 45]
 [-4 95]]
In [432]: a=np.array([[1,2],[3,4]]); b=np.array([[0,5],[-1,20]])                
In [433]: np.tensordot(a,b,axes=(1,0))                                          
Out[433]: 
array([[-2, 45],
       [-4, 95]])

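The axes=(1,0) argument means that axis 1 of a and axis 0 of b are the sum-of-products axes: elements are multiplied pairwise along those axes, and the products are added up. That's just the normal np.dot pairing: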

In [434]: np.dot(a,b)                                                           
Out[434]: 
array([[-2, 45],
       [-4, 95]])
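To make "summed over" concrete, here is a minimal sketch that writes the same contraction as explicit loops (plain loops for illustration only, not how NumPy computes it). The index j runs along the paired axes (axis 1 of a, axis 0 of b), and every product along it is added into one output entry:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[0, 5], [-1, 20]])

# axes=(1, 0): pair axis 1 of a with axis 0 of b (both have length 2).
# For every remaining index pair (i, k), multiply matching elements
# along the paired axis and add the products up.
c = np.zeros((a.shape[0], b.shape[1]), dtype=a.dtype)
for i in range(a.shape[0]):
    for k in range(b.shape[1]):
        for j in range(a.shape[1]):  # j runs along the summed axis
            c[i, k] += a[i, j] * b[j, k]

print(c)
# [[-2 45]
#  [-4 95]]
```

Note that the arrays are never summed on their own; the sum is always over products a[i, j] * b[j, k] that share the same j.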

I find the einsum notation to be clearer:

In [435]: np.einsum('ij,jk->ik',a,b)                                            
Out[435]: 
array([[-2, 45],
       [-4, 95]])
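In einsum terms, the axes tuple just says which axes share a letter. A small sketch: the repeated letter that is missing from the output is the one summed over. Changing the pairing to axes=(0,0) (chosen here only for contrast) moves the shared letter to axis 0 of both arrays, which amounts to a.T @ b:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[0, 5], [-1, 20]])

# axes=(1, 0): the shared letter j labels axis 1 of a and axis 0 of b.
# j does not appear after '->', so einsum sums over it.
c10 = np.einsum('ij,jk->ik', a, b)
assert np.array_equal(c10, np.tensordot(a, b, axes=(1, 0)))

# axes=(0, 0): the shared letter now labels axis 0 of BOTH arrays,
# which is the same as a.T @ b.
c00 = np.einsum('ji,jk->ik', a, b)
assert np.array_equal(c00, np.tensordot(a, b, axes=(0, 0)))
assert np.array_equal(c00, a.T @ b)
print(c00)
# [[-3 65]
#  [-4 90]]
```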

In any case, this is the matrix product we learned in school: run your finger across the rows of a, and down the columns of b.

[[1*0+2*-1, 1*5+2*20], ...]  
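The "finger" picture can be written with row and column slices. A short sketch: entry (i, k) is the dot product of row a[i, :] with column b[:, k], and the comments spell out the arithmetic for all four entries:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[0, 5], [-1, 20]])

# Entry (i, k) is row a[i, :] dotted with column b[:, k].
rows = []
for i in range(a.shape[0]):
    rows.append([int(np.dot(a[i, :], b[:, k])) for k in range(b.shape[1])])
print(rows)  # [[-2, 45], [-4, 95]]

# The same numbers by hand:
# row 0: [1*0 + 2*(-1), 1*5 + 2*20] = [-2, 45]
# row 1: [3*0 + 4*(-1), 3*5 + 4*20] = [-4, 95]
```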

Yet another expression, expanding on the einsum one with broadcasting:

In [440]: (a[:,:,None]*b[None,:,:]).sum(axis=1)                                 
Out[440]: 
array([[-2, 45],
       [-4, 95]])
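It can help to look at the intermediate array before the sum. This sketch shows that broadcasting builds every product a[i, j] * b[j, k] in a (2, 2, 2) array with j on axis 1, so summing axis 1 is exactly the "summed over" step:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[0, 5], [-1, 20]])

# a[:, :, None] has shape (2, 2, 1): axes (i, j, -)
# b[None, :, :] has shape (1, 2, 2): axes (-, j, k)
# Broadcasting gives all products a[i, j] * b[j, k] at once.
prods = a[:, :, None] * b[None, :, :]
print(prods.shape)  # (2, 2, 2)

# Before summing, the two products for entry (0, 0) sit along axis 1:
print(prods[0, :, 0].tolist())  # [0, -2], i.e. 1*0 and 2*(-1)

# Summing over axis 1 (the j axis) collapses them into the result.
c = prods.sum(axis=1)
print(c.tolist())  # [[-2, 45], [-4, 95]]
```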

tensordot reshapes and transposes axes, aiming to reduce the problem to a simple call to np.dot. It then reshapes/transposes the result back as needed. The details depend on the axes parameter. In your case no reshaping is needed, since your specification matches the default dot action.
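To illustrate that transpose/reshape/dot idea, here is a rough sketch using a hypothetical helper, tensordot_sketch, written for this answer (it is not NumPy's actual source and skips its shape checks). It moves the summed axes to the end of a and the front of b, flattens them into one dimension, does a single np.dot, and reshapes back:

```python
import numpy as np

def tensordot_sketch(a, b, axes_a, axes_b):
    """Hypothetical helper: contract axes_a of a with axes_b of b via
    transpose -> reshape to 2-D -> np.dot -> reshape back.
    Illustration only; no shape validation."""
    axes_a = [ax % a.ndim for ax in axes_a]
    axes_b = [ax % b.ndim for ax in axes_b]
    free_a = [ax for ax in range(a.ndim) if ax not in axes_a]
    free_b = [ax for ax in range(b.ndim) if ax not in axes_b]

    # Summed axes go last in a and first in b, in matching order.
    at = a.transpose(free_a + axes_a)
    bt = b.transpose(axes_b + free_b)

    # Collapse the summed axes into one dimension of length n_sum.
    n_sum = int(np.prod([a.shape[ax] for ax in axes_a]))
    a2 = at.reshape(-1, n_sum)
    b2 = bt.reshape(n_sum, -1)

    # One ordinary matrix product does all the summing.
    out = np.dot(a2, b2)

    # Restore the free axes: a's first, then b's.
    out_shape = [a.shape[ax] for ax in free_a] + [b.shape[ax] for ax in free_b]
    return out.reshape(out_shape)

# The 2-D case from the question: no real transposing is needed.
a = np.array([[1, 2], [3, 4]])
b = np.array([[0, 5], [-1, 20]])
print(tensordot_sketch(a, b, [1], [0]).tolist())  # [[-2, 45], [-4, 95]]

# A 3-D case where the transposes and reshapes actually matter:
# axis 1 of x pairs with axis 1 of y, axis 2 of x with axis 0 of y.
rng = np.random.default_rng(0)
x = rng.integers(-5, 5, size=(2, 3, 4))
y = rng.integers(-5, 5, size=(4, 3, 5))
r1 = tensordot_sketch(x, y, [1, 2], [1, 0])
r2 = np.tensordot(x, y, axes=([1, 2], [1, 0]))
print(r1.shape)                # (2, 5)
print(np.array_equal(r1, r2))  # True
```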

A tuple axes parameter is relatively easy to explain. There is also the scalar axes case (0, 1, 2, etc.), which is a bit trickier; I've explored that in another post.
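A brief sketch of the scalar form on the same arrays: axes=N sums over the last N axes of a paired with the first N axes of b. So axes=1 matches axes=(1,0) here, axes=0 sums nothing (an outer product), and axes=2 (the default) pairs both axes, leaving a 0-d result:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[0, 5], [-1, 20]])

# axes=1: last axis of a with first axis of b, same as axes=(1, 0) here.
assert np.array_equal(np.tensordot(a, b, axes=1),
                      np.tensordot(a, b, axes=(1, 0)))

# axes=0: nothing is summed, so every a[i, j] * b[k, l] is kept.
outer = np.tensordot(a, b, axes=0)
print(outer.shape)  # (2, 2, 2, 2)
assert np.array_equal(outer, a[:, :, None, None] * b[None, None, :, :])

# axes=2: a[i, j] pairs with b[i, j], so the result is sum(a * b).
full = np.tensordot(a, b, axes=2)
print(full)  # 87  (1*0 + 2*5 + 3*(-1) + 4*20)
assert full == (a * b).sum()
```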


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM