
TensorFlow: integrate output of neural network

I have a neural network that takes as input two parameters:

t = tf.placeholder(tf.float32, [None, 1])
x = tf.placeholder(tf.float32, [None, 1])

In my loss function I need to integrate the output over t, but I can't figure out how to do this. The only numerical integration function available in TensorFlow, tf.contrib.integrate.odeint_fixed, expects a callable as its first argument, so it cannot take a Tensor:

Call

t = tf.constant(np.linspace(0.0,1.0,100), dtype = tf.float64 )

integ = tf.contrib.integrate.odeint_fixed(model.output, 
                                          0.0, 
                                          t, 
                                          method = "rk4")

Output

...

<ipython-input-5-c79e79b75391> in loss(model, t, x)
     24                                                 0.0,
     25                                                 t,
---> 26                                                 method = "rk4")

...

TypeError: 'Tensor' object is not callable

I'm also unsure how to treat x in this computation; it is supposed to be held fixed.

tf.contrib.integrate.odeint_fixed is meant for integrating ordinary differential equations (ODEs), which is why it expects a callable rather than a Tensor. If I understand you correctly, however, you want to approximate the definite integral of your model's output, let's call it y, sampled at t.

To do so, you could use the trapezoidal rule; a possible implementation can be found in TensorFlow's AUC function. In your case, it could look like:

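from tensorflow.python.ops import math_ops

def trapezoidal_integral_approx(t, y):
    # t: sample points in increasing order; y: model output at those points.
    # Each interval contributes its width times the mean of its two endpoint values.
    return math_ops.reduce_sum(
            math_ops.multiply(t[1:] - t[:-1],
                              (y[:-1] + y[1:]) / 2.),
            name='trapezoidal_integral_approx')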

where y would be the output of your model, evaluated at the points in t.
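As a rough illustration of the trapezoidal rule with x held fixed, here is a minimal NumPy sketch. The function toy_model is a hypothetical stand-in for the network (it is not from the question), chosen so that the exact integral is known; the [N, 1] shapes mirror the placeholders in the question, and x is simply repeated for every t sample.

```python
import numpy as np

def trapezoidal_integral(t, y):
    """Approximate the integral of y over t with the trapezoidal rule.

    t and y are arrays of shape [N, 1]; t must be in increasing order.
    """
    dt = t[1:] - t[:-1]               # widths of the N-1 sub-intervals
    mid = (y[1:] + y[:-1]) / 2.0      # mean of y at each interval's endpoints
    return np.sum(dt * mid)

def toy_model(t, x):
    # Hypothetical stand-in for the network's forward pass: y = x * t**2.
    return x * t ** 2

n = 101
t = np.linspace(0.0, 1.0, n).reshape(-1, 1)   # shape [N, 1], like the t placeholder
x = np.full((n, 1), 3.0)                      # the same x value for every t sample

y = toy_model(t, x)                           # shape [N, 1]
integral = trapezoidal_integral(t, y)         # exact value is 3 * (1/3) = 1
```

The same slicing expressions work on TensorFlow tensors, with tf.reduce_sum in place of np.sum, so the approximation can sit inside a loss function.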
