
Theano gradient with function on tensors

I have a function that calculates the value of a scalar field in 3D space, so I feed it 3D tensors for the x, y and z coordinates (obtained with numpy.meshgrid) and use elementwise operations everywhere. This works as expected.

Now I need to calculate the gradient of the scalar field. I've been playing around with theano.tensor.grad and theano.tensor.jacobian, and I don't understand how the derivative of an elementwise operation is supposed to work.

Here is a minimal working example (MWE) whose output I don't understand:

import theano.tensor as T 

x, y = T.matrices("xy")

expr = x**2 + y
grad = T.grad(expr[0, 0], x)  # cost is the single scalar expr[0, 0]
print(grad.eval({x: [[1, 2], [1, 2]], y: [[1, 1], [2, 2]]}))

It prints

[[ 2.  0.]
 [ 0.  0.]]

while I would expect

[[ 2.  4.]
 [ 2.  4.]]

I also tried T.jacobian:

import theano.tensor as T

x, y = T.matrices("xy")

expr = x**2 + y
grad = T.jacobian(expr.flatten(), x)  # one 2x2 slice per element of expr
print(grad.eval({x: [[1, 2], [1, 2]], y: [[1, 1], [2, 2]]}))

which returns

[[[ 2.  0.]
  [ 0.  0.]]

 [[ 0.  4.]
  [ 0.  0.]]

 [[ 0.  0.]
  [ 2.  0.]]

 [[ 0.  0.]
  [ 0.  4.]]]

(taken together, the nonzero elements form exactly the matrix I expected in the previous example)
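To see how the two outputs relate, here is a NumPy-only sketch (it does not call Theano; the expression x**2 + y and the values of x come from the question) that builds the same (4, 2, 2) Jacobian by hand. Because expr is elementwise, each slice has a single nonzero entry, so summing over the first axis, or reading the diagonal of the Jacobian viewed as a 4x4 matrix, recovers the elementwise derivatives.

```python
import numpy as np

# x from the question; y is not needed because the derivative of x**2 + y
# with respect to x does not depend on y.
x = np.array([[1.0, 2.0], [1.0, 2.0]])

# Jacobian of expr.flatten() with respect to x, laid out like the question's
# T.jacobian output: one 2x2 slice per element of the flattened expression.
n = x.size
jac = np.zeros((n,) + x.shape)
for k, idx in enumerate(np.ndindex(x.shape)):  # C order, same as flatten()
    # expr.flatten()[k] = x[idx]**2 + y[idx] depends only on x[idx].
    jac[(k,) + idx] = 2 * x[idx]

# Each slice has one nonzero entry, so summing over the output axis merges them.
print(jac.sum(axis=0))  # equals [[2, 4], [2, 4]]

# Equivalently, the diagonal of the Jacobian reshaped to 4x4 holds the
# elementwise derivatives d expr[i, j] / d x[i, j].
diag = np.diag(jac.reshape(n, n)).reshape(x.shape)
print(diag)  # equals [[2, 4], [2, 4]]
```

In general, summing the Jacobian over its output axis is a vector-Jacobian product with a vector of ones; it coincides with the elementwise derivatives only because every output element depends on a single input element.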

Is there a way to get the elementwise gradients I need?

Could I, for example, define the function as a scalar function (three scalars in, one scalar out) and apply it elementwise over the coordinate tensors? That way the derivative would also be a simple scalar, and everything would work smoothly.

Taken as a cost, the first element expr[0, 0] depends only on the first element of x, x[0, 0], so the result you are getting is correct: its derivative is 2*x[0, 0] = 2 in that position and 0 everywhere else.
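One way to check this without Theano is a central-difference test in NumPy. The helper numerical_grad below is hypothetical (written only for this sketch, not part of any library); it perturbs one entry of x at a time and measures how the scalar cost expr[0, 0] changes.

```python
import numpy as np

def expr(x, y):
    # Same elementwise expression as in the question.
    return x**2 + y

def numerical_grad(cost, x, h=1e-6):
    """Hypothetical helper: central-difference gradient of a scalar cost w.r.t. every entry of x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x)
        step[idx] = h
        g[idx] = (cost(x + step) - cost(x - step)) / (2 * h)
    return g

x = np.array([[1.0, 2.0], [1.0, 2.0]])
y = np.array([[1.0, 1.0], [2.0, 2.0]])

# The cost is the single element expr[0, 0]; perturbing any other entry of x
# leaves it unchanged, so those partial derivatives are exactly zero.
g_single = numerical_grad(lambda v: expr(v, y)[0, 0], x)
print(g_single)  # approximately [[2, 0], [0, 0]]
```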

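The result you expect is what you get if you sum the whole expr array and differentiate that scalar instead. Since each expr[i, j] depends only on x[i, j], the derivative of the sum with respect to x[i, j] is simply 2*x[i, j], and Theano takes care of backpropagating the gradient through the sum: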

import theano.tensor as T 

x, y = T.matrices("xy")

expr = x**2 + y
grad = T.grad(expr.sum(), x)  # cost is the sum of all elements of expr
print(grad.eval({x: [[1, 2], [1, 2]], y: [[1, 1], [2, 2]]}))

prints

[[ 2.  4.]
 [ 2.  4.]]
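The same trick should carry over to the 3D scalar field from the question: in Theano it would presumably look like T.grad(f(X, Y, Z).sum(), [X, Y, Z]), where f, X, Y and Z are placeholder names and that line is untested here. As an independent numerical cross-check, the sketch below uses NumPy's np.gradient on a meshgrid with a hypothetical field f(x, y, z) = x**2 + y*z. Note that this is finite differencing, not symbolic differentiation; it only approximates the gradient in general, although it happens to be exact (up to rounding) for this quadratic field.

```python
import numpy as np

def field(x, y, z):
    # Hypothetical scalar field, evaluated elementwise on the coordinate arrays.
    return x**2 + y * z

xs = np.linspace(-1.0, 1.0, 5)
ys = np.linspace(0.0, 2.0, 6)
zs = np.linspace(-2.0, 2.0, 7)

# indexing="ij" makes axes 0, 1, 2 follow x, y, z, matching the order of the
# coordinate arrays passed to np.gradient below.
X, Y, Z = np.meshgrid(xs, ys, zs, indexing="ij")
F = field(X, Y, Z)

# Numerical partial derivatives at every grid point; edge_order=2 uses
# second-order one-sided differences at the boundaries.
dF_dx, dF_dy, dF_dz = np.gradient(F, xs, ys, zs, edge_order=2)

# Compare with the analytic gradient of the hypothetical field: (2x, z, y).
print(np.allclose(dF_dx, 2 * X), np.allclose(dF_dy, Z), np.allclose(dF_dz, Y))
# expected: True True True
```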

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 
Guangdong ICP Filing No. 18138465  © 2020-2024 STACKOOM.COM