Numpy array dot product

We all know that the dot product of two vectors must return a scalar:

import numpy as np
a = np.array([1,2,3])
b = np.array([3,4,5])
print(a.shape) # (3,)
print(b.shape) # (3,)
a.dot(b) # 26
b.dot(a) # 26

Perfect. But why does the numpy dot product raise a dimension error when we use a "real" row vector or column vector (see Difference between numpy.array shape (R, 1) and (R,))?

arow = np.array([[1,2,3]])
brow = np.array([[3,4,5]])
print(arow.shape) # (1,3)
print(brow.shape) # (1,3)
arow.dot(brow) # ERROR
brow.dot(arow) # ERROR

acol = np.array([[1,2,3]]).reshape(3,1)
bcol = np.array([[3,4,5]]).reshape(3,1)
print(acol.shape) # (3,1)
print(bcol.shape) # (3,1)
acol.dot(bcol) # ERROR
bcol.dot(acol) # ERROR

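Because by explicitly adding a second dimension, you are no longer working with vectors but with two-dimensional matrices. When taking the dot product of two matrices, the inner dimensions must match: the number of columns of the first must equal the number of rows of the second. A (1,3) times a (1,3) fails because 3 != 1, and a (3,1) times a (3,1) fails because 1 != 3.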

You therefore need to transpose one of your matrices. Which one you transpose will determine the meaning and shape of the result.

A 1x3 times a 3x1 matrix will result in a 1x1 matrix (i.e., a single value, though numpy still returns it as a two-dimensional array rather than a true scalar). This is the inner product. A 3x1 times a 1x3 matrix will result in a 3x3 outer product.
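
As a minimal sketch of the two cases described above, reusing the `arow` and `brow` arrays from the question (the names `inner` and `outer` are only illustrative), transposing the second or the first operand gives the inner or the outer product. Calling `.item()` is one way to pull the single value out of the 1x1 result:

```python
import numpy as np

arow = np.array([[1, 2, 3]])   # shape (1, 3)
brow = np.array([[3, 4, 5]])   # shape (1, 3)

# (1, 3) times (3, 1): inner dimensions match (3 == 3), result is (1, 1)
inner = arow.dot(brow.T)
print(inner.shape)    # (1, 1)
print(inner.item())   # 26, extracted as a plain Python scalar

# (3, 1) times (1, 3): inner dimensions match (1 == 1), result is (3, 3)
outer = arow.T.dot(brow)
print(outer.shape)    # (3, 3)
```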

You can also use the @ operator, which performs matrix multiplication. As with dot, you need to be aware of the array shapes (the ndarray dimensions must always be compatible), but it is more readable:

>>> a = np.array([1,2,3])
>>> a.shape
(3,)
>>> b= np.array([[1,2,3]])
>>> b.shape
(1, 3)
>>> a@b
Traceback (most recent call last):
  File "<input>", line 1, in <module>
ValueError: shapes (3,) and (1,3) not aligned: 3 (dim 0) != 1 (dim 0)
>>> a@b.T
array([14])
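
If what you want is a plain scalar regardless of whether an operand is 1-D or 2-D, one possible approach (a sketch, not the only option) is to flatten the operands first. The variable names `s1` and `s2` are illustrative; `ravel` and `np.vdot` are standard numpy calls, and `np.vdot` flattens its arguments before multiplying:

```python
import numpy as np

a = np.array([1, 2, 3])          # shape (3,)
b = np.array([[1, 2, 3]])        # shape (1, 3)

# Flatten b to 1-D so that @ works on two vectors and yields a scalar
s1 = a @ b.ravel()

# np.vdot flattens both arguments before taking the dot product
s2 = np.vdot(a, b)

print(s1, s2)  # 14 14
```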

You can also compute the dot product by hand with an explicit loop:

import numpy as npy

Vector1 = npy.array([0,2,3])
Vector2 = npy.array([3,5,1])
print("Dot Product of", Vector1, "and", Vector2)

def DotProduct(a,b):
    # Accumulate the sum of the element-wise products
    NetValue = 0
    for i in range(len(a)):
        NetValue += a[i]*b[i]
    return NetValue

ans = DotProduct(Vector1,Vector2)
print("The answer is =",ans)
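
For comparison, the same sum of element-wise products can be obtained with vectorized numpy calls instead of a Python loop. This is a small sketch using the same two vectors; the names `v1`, `v2`, `by_dot`, `by_sum` and `by_matmul` are illustrative:

```python
import numpy as np

v1 = np.array([0, 2, 3])
v2 = np.array([3, 5, 1])

by_dot = np.dot(v1, v2)        # built-in dot product
by_sum = np.sum(v1 * v2)       # element-wise product, then sum
by_matmul = v1 @ v2            # matrix-multiplication operator on 1-D arrays

print(by_dot, by_sum, by_matmul)  # 13 13 13
```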
