
Dot product dimension issue with theano

This is my first question on Stack Overflow. I hope it fits the expected format.

I'm simply trying to make a function that computes a dot product between a matrix and a vector:

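import numpy
import theano.tensor as T
from theano import function

A = T.matrix()   # symbolic 2-D input
B = T.vector()   # symbolic 1-D input
C = T.dot(A,B)

a = numpy.array([[1,2],[3,4]])
b = numpy.array([[1],[0]])

f = function([A,B],C)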

When I compute the product with NumPy, I have no issue: numpy.dot(a,b) returns the expected array([[1],[3]]).

But when I call f(a,b) in Theano, I get this error: TypeError: ('Bad input argument to theano function with name "<stdin>:1" at index 1(0-based)', 'Wrong number of dimensions: expected 1, got 2 with shape (2, 1).')

My understanding is that Theano's typing requires B, which I declared as a vector, to be 1-D. I therefore tried b = numpy.array([1,0]) instead, hoping Theano would understand that it has to be treated as a column to compute the product. Now the result I get is array([[ 2., 2.],[ 4., 4.]]), which is even more of a mystery to me.

I've found similar topics, but none with an explanation I could grasp. I realize it must seem obvious to many of you, but I'd appreciate a hand!

When you type the following

b = numpy.array([1,0])

NumPy creates a 1-D array of shape (2,). It has a single axis, so it is neither a row matrix nor a column matrix, and it is exactly what a Theano vector expects.

Now when you type in,

b = numpy.array([[1],[0]])

and then pass it to the function f, Theano raises an error because B is declared as a vector (1-D) while you are passing a 2-D array of shape (2, 1), i.e. a matrix, as the second argument. Change B to type matrix and it works fine.
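
The dimension mismatch reported in the error can be seen by inspecting the two arrays in NumPy alone. This is only a minimal sketch; the names b_vec and b_col are illustrative and do not appear in the question:

```python
import numpy

b_vec = numpy.array([1, 0])      # 1-D array: one axis of length 2
b_col = numpy.array([[1], [0]])  # 2-D array: two rows, one column

print(b_vec.ndim, b_vec.shape)   # 1 (2,)
print(b_col.ndim, b_col.shape)   # 2 (2, 1)
```

The "expected 1, got 2 with shape (2, 1)" part of the error message refers to exactly these ndim and shape values.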

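So, either do this:

import numpy
import theano.tensor as T
from theano import function

A = T.matrix()
B = T.matrix()   # B is now 2-D, so a (2, 1) array is accepted
C = T.dot(A,B)

a = numpy.array([[1,2],[3,4]])
b = numpy.array([[1], [0]])

f = function([A,B],C)

print(f(a,b))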

Or you can keep B as a vector and pass a 1-D array:

import numpy
import theano.tensor as T
from theano import function

A = T.matrix()
B = T.vector()   # B is 1-D, so pass an array of shape (2,)
C = T.dot(A,B)

a = numpy.array([[1,2],[3,4]])
b = numpy.array([1,0])

f = function([A,B],C)

print(f(a,b))
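
Note that the two variants return results of different shapes. NumPy's dot follows the same dimension rules for these inputs, so the difference can be sketched without Theano. The snippet also uses ravel() as one possible way to flatten an existing column into a 1-D array; the names b_col, b_vec, res_col and res_vec are illustrative only:

```python
import numpy

a = numpy.array([[1, 2], [3, 4]])
b_col = numpy.array([[1], [0]])   # shape (2, 1), as in the matrix variant
b_vec = b_col.ravel()             # shape (2,), as in the vector variant

res_col = numpy.dot(a, b_col)     # matrix . matrix gives a 2-D result
res_vec = numpy.dot(a, b_vec)     # matrix . vector gives a 1-D result

print(res_col.shape)              # (2, 1)
print(res_vec.shape)              # (2,)
print(res_vec)                    # [1 3]
```

Which variant to prefer depends on whether the rest of the code expects a column matrix or a flat vector.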

Hope this helped!
