
How to use tf.while_loop for a nested for loop

I am working on a model that I implemented in TensorFlow. Everything was going well until I had to implement a nested function (something like)

[equation image]

using a nested for loop. For instance, something like the following:

for k in range(n):
    # Repeat G[k] once per row of trueX (the batch dimension)
    Gk = tf.convert_to_tensor(G[k], dtype=self.dtype)
    Gk = tf.tile(tf.expand_dims(Gk, 0), [tf.shape(trueX)[0], 1])

    for j in range(m):
        # A separate network per index j, selected by its name
        pre_Y = self.build_nn(Gk, 200, '{}phi2'.format(j), tf.AUTO_REUSE)
        log_pred_Y = tf.layers.dense(pre_Y, 2, name='{}phi2y'.format(j), reuse=tf.AUTO_REUSE)
        pred_Y = tf.exp(log_pred_Y)
        ...........
        pre_cost += tf.multiply(ll, pred_Y)
    _cost += tf.reduce_sum(tf.multiply(pre_cost, pre_G), -1)

This was obviously a very bad idea, as it is not only slow but also consumes a huge chunk of memory. For days I've been trying to reimplement it using tf.while_loop, which I understand should be better than a native Python loop. I've been trying to implement a nested tf.while_loop, but to no avail. I have 2 questions:

  1. Is it a good idea to use a nested tf.while_loop (i.e. a tf.while_loop inside another one, to capture the computation of the nested for loop)? If not, is it possible to capture the computation of a nested for loop using tf.while_loop some other way?
  2. How can one use a tf.while_loop when a different dense layer is needed in each iteration? E.g. in my case I have:

pre_Y  = self.build_nn( Gk, 200, '{}phi2'.format(j), tf.AUTO_REUSE ) 
log_pred_Y = tf.layers.dense( pre_Y, 2, name='{}phi2y'.format(j), reuse=tf.AUTO_REUSE )

which uses the index in the name to identify which layer to use in each iteration. With the Python for loop, I use the index to build the name of the layer to reuse. With tf.while_loop, I can't find any way to give each dense layer a unique name, since, as I understand it, I can't pass a Python string or integer into tf.while_loop to use in the "name" option.

Re (1): nested tf.while_loops work just fine.
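
As a rough illustration of that point, here is a minimal sketch of one tf.while_loop nested inside another. It assumes TensorFlow 2.x with eager execution; under TF 1.x graph mode the same construction applies, but the result would have to be evaluated in a session. The function name nested_sum and the toy computation (summing k * j over both indices) are hypothetical and only stand in for the real loop body.

```python
import tensorflow as tf


def nested_sum(n, m):
    """Sum k * j over k in [0, n) and j in [0, m) with nested tf.while_loop calls."""

    def outer_cond(k, total):
        return k < n

    def outer_body(k, total):
        # The inner loop is built inside the outer loop's body and can
        # read the outer loop variable k directly.
        def inner_cond(j, acc):
            return j < m

        def inner_body(j, acc):
            return j + 1, acc + tf.cast(k * j, tf.float32)

        _, inner_total = tf.while_loop(
            inner_cond, inner_body, (tf.constant(0), tf.constant(0.0))
        )
        return k + 1, total + inner_total

    _, total = tf.while_loop(
        outer_cond, outer_body, (tf.constant(0), tf.constant(0.0))
    )
    return total
```

Calling nested_sum(3, 4) returns a scalar float32 tensor holding the double sum. Note that the same inner body is used for every k and j, which is why this pattern does not help with question (2).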

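Re (2): this kind of thing cannot be done with tf.while_loop, because the loop body is built once and the loop index is a tensor, not a Python value that can go into a layer name. You'll need to unroll your loop in Python.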
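
A minimal sketch of what unrolling in Python could look like, assuming TensorFlow 2.x and Keras layers rather than the tf.layers.dense calls from the question. The helper name unrolled_predictions, the layer names, and the single ReLU Dense layer standing in for self.build_nn are all assumptions made for illustration, since build_nn is not shown in the question.

```python
import tensorflow as tf

m = 3  # number of distinct sub-networks (hypothetical value)

# One pair of layers per index j, created up front in plain Python.
# The single Dense(200) layer is a stand-in for the question's build_nn.
hidden_layers = [
    tf.keras.layers.Dense(200, activation="relu", name="phi2_{}".format(j))
    for j in range(m)
]
output_layers = [
    tf.keras.layers.Dense(2, name="phi2y_{}".format(j)) for j in range(m)
]


def unrolled_predictions(Gk):
    """Apply the j-th network to Gk for every j, using a Python loop."""
    preds = []
    for j in range(m):
        pre_Y = hidden_layers[j](Gk)          # layer chosen by Python index
        log_pred_Y = output_layers[j](pre_Y)
        preds.append(tf.exp(log_pred_Y))
    return preds
```

Because j is an ordinary Python integer here, each iteration can pick its own layers, and each layer keeps its own weights. The cost is that the graph grows with m, which is the trade-off the question ran into.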


 