
How to use a tensor as an input of a function? (TensorFlow)

I am trying to use TensorFlow to build a model that solves a differential equation, for example,

dX/dt=f(\mu,X,t)

Here, \mu is a function of X. It is complicated, so I want to predict \mu(X) with a neural network.

First, my input X passes through a dense layer N to give \mu ~ N(X). Then I solve the ODE above with the Runge-Kutta method, defined by this code:

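def RK4(self, mu, X, t, dt=0.2):
    # One classical fourth-order Runge-Kutta step from t to t + dt.
    kX1 = dt * self.f(mu, X, t)
    kX2 = dt * self.f(mu, X + kX1 / 2, t + dt / 2)
    kX3 = dt * self.f(mu, X + kX2 / 2, t + dt / 2)
    kX4 = dt * self.f(mu, X + kX3, t + dt)
    # Weighted average of the four slope estimates.
    X_next = X + (kX1 + 2 * kX2 + 2 * kX3 + kX4) / 6

    return X_next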
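As a sanity check of the update rule itself, independent of TensorFlow, here is a minimal standalone sketch. The right-hand side f below is a hypothetical stand-in (dX/dt = -mu * X, chosen only because its exact solution is known); it is not the f from my model, and rk4_step is just the method above rewritten as a free function.

```python
import math

def f(mu, x, t):
    # Hypothetical right-hand side, for illustration only: dX/dt = -mu * X.
    return -mu * x

def rk4_step(f, mu, x, t, dt=0.2):
    """One classical fourth-order Runge-Kutta step from t to t + dt."""
    k1 = dt * f(mu, x, t)
    k2 = dt * f(mu, x + k1 / 2, t + dt / 2)
    k3 = dt * f(mu, x + k2 / 2, t + dt / 2)
    k4 = dt * f(mu, x + k3, t + dt)
    return x + (k1 + 2 * k2 + 2 * k3 + k4) / 6

mu, x, t, dt = 1.5, 1.0, 0.0, 0.2
for _ in range(5):  # five steps of 0.2 integrate from t = 0 to t = 1
    x = rk4_step(f, mu, x, t, dt)
    t += dt

exact = math.exp(-mu * 1.0)  # exact solution X(1) for X(0) = 1
error = abs(x - exact)
```

With plain Python floats the step behaves as expected, so the error I am seeing comes from how the tensors are passed in, not from the Runge-Kutta formula.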

Note that RK4 is a method of a class, which is where self comes from. When I pass N(X) directly into RK4, this error occurs:

 Tensor objects are only iterable when eager execution is enabled. To iterate 
 over this tensor use tf.map_fn.

I'm not familiar with tf.map_fn. My function is complicated because it takes both tensors (\mu, X) and floats (t, dt), but as far as I know, map_fn only handles tensor inputs. Is there a smart way to deal with these inputs? Thanks!

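The following call solves the problem:

X_next = tf.map_fn(lambda x: self.RK4(x[0], x[1], x[2]), (self.mu, self.X, self.t), dtype=tf.float32)

tf.map_fn accepts a tuple of tensors as its elems argument, so the tensor inputs (\mu, X) and the float input t can be mapped together, provided t is itself a float tensor with the same leading dimension as the others. Each tensor in the tuple is unstacked along its first dimension, and the lambda receives one slice of each per call. Usage of the function is shown in the link from the original post.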
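A minimal runnable sketch of the same idea follows, under several assumptions: it targets TensorFlow 2.3 or later, where fn_output_signature replaces the older dtype argument used above; f is a hypothetical right-hand side (dX/dt = -mu * X) standing in for the real one; and rk4 is written as a free function rather than a method. The shapes are made up for illustration.

```python
import tensorflow as tf

def f(mu, x, t):
    # Hypothetical right-hand side, for illustration only: dX/dt = -mu * X.
    return -mu * x

def rk4(mu, x, t, dt=0.2):
    # One classical fourth-order Runge-Kutta step.
    k1 = dt * f(mu, x, t)
    k2 = dt * f(mu, x + k1 / 2, t + dt / 2)
    k3 = dt * f(mu, x + k2 / 2, t + dt / 2)
    k4 = dt * f(mu, x + k3, t + dt)
    return x + (k1 + 2 * k2 + 2 * k3 + k4) / 6

# One row per sample; every tensor in the tuple shares the leading dimension 3.
mu = tf.constant([[0.5], [1.0], [1.5]], dtype=tf.float32)
X = tf.constant([[1.0], [2.0], [3.0]], dtype=tf.float32)
t = tf.zeros([3], dtype=tf.float32)  # one time value per sample, not a Python float

# The lambda receives a tuple (mu_i, X_i, t_i) holding one slice of each tensor.
X_next = tf.map_fn(
    lambda e: rk4(e[0], e[1], e[2]),
    (mu, X, t),
    fn_output_signature=tf.float32,  # needed: output structure differs from elems
)

# If f uses only elementwise arithmetic, as this stand-in does, calling rk4
# on the whole batch at once gives the same values without map_fn.
X_direct = rk4(mu, X, 0.0)
```

Whether the direct call is an option depends on the real f; if it indexes or iterates over its inputs, the map_fn form is the safer one.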
