
Correct use of tf.while_loop when variables are created inside body

I'm using a while_loop in TensorFlow to iterate over a tensor and extract specific slices along a given dimension. For each step, I need to use a decoder RNN to generate a sequence of output symbols. I'm using the code provided in tf.contrib.seq2seq , in particular tf.contrib.seq2seq.dynamic_decode . The code looks similar to the following:

def decoder_condition(i, data, source_seq_len, ta_outputs):
    return tf.less(i, max_loop_len)

def decoder_body(i, data, source_seq_len, ta_outputs):
    curr_data = data[:, i, :]
    curr_source_seq_len = source_seq_len[:, i, :]
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        2 * self.opt["encoder_rnn_h_size"],
        curr_data,
        memory_sequence_length=curr_source_seq_len
    )
    cell = GRUCell(num_units)
    cell = AttentionWrapper(cell, attention_mechanism)
    # ... other code that initialises all the variables required
    # for the RNN decoder
    # dynamic_decode returns (final_outputs, final_state, final_sequence_lengths)
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=self.opt["max_sys_seq_len"],
        swap_memory=True
    )
    # write the decoder's rnn_output tensor (assuming a BasicDecoder-style output)
    with tf.control_dependencies([outputs.rnn_output]):
        ta_outputs = ta_outputs.write(i, outputs.rnn_output)

    # the body must return the same loop variables, in the same order, as it receives
    return i + 1, data, source_seq_len, ta_outputs

loop_index = tf.constant(0)
ta_outputs = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
_, _, _, ta_outputs = tf.while_loop(
      decoder_condition,
      decoder_body,
      loop_vars=[
          loop_index,
          data,
          data_source_len,
          ta_outputs
      ],
      swap_memory=True,
      back_prop=True, 
      parallel_iterations=1
)

So as you can see, I create different objects that depend specifically on the input at the current step i . I'm using tf.AUTO_REUSE in my current variable scope so that the variables are reused even though I'm creating different objects. Unfortunately, my decoder doesn't seem to train properly: it keeps generating incorrect values. I've already checked the input data to the decoder RNN and everything is correct. I suspect there is something I'm not doing properly in terms of how TensorFlow manages the TensorArray and the while_loop.
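
To make the pattern concrete, here is a minimal, simplified sketch (not my actual model; the scope name "decoder" and the variable "w" are just placeholders) of what I mean by creating a variable under tf.AUTO_REUSE inside the loop body:

import tensorflow as tf

def body(i, acc):
    # the scope is opened with AUTO_REUSE, so the same variable "w" should be
    # returned on every iteration instead of a fresh one being created
    with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
        w = tf.get_variable("w", shape=[4, 4], dtype=tf.float32)
    return i + 1, acc + tf.reduce_sum(w)

i0 = tf.constant(0)
acc0 = tf.constant(0.0)
_, total = tf.while_loop(lambda i, _: tf.less(i, 3), body, [i0, acc0])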

So my main questions are:

  1. Is TensorFlow correctly propagating the gradients for each variable that is created inside the while loop?
  2. Is it possible to create objects inside the while loop that depend on specific slices of a tensor obtained using the loop index?
  3. Does the back_prop parameter guarantee that the gradients are propagated during training? Should it be set to False during inference?
  4. In general, are there any sanity checks that I can use to spot possible errors in my implementation?

Thanks!

UPDATE: Not sure why, but it seems there is an open issue about this, related to the possibility of invoking custom operations in a while loop, as explained here: https://github.com/tensorflow/tensorflow/issues/13616 . Unfortunately, I don't know enough about TensorFlow's internals to judge whether it's really related.

UPDATE 2: I solved it by using PyTorch :)

(1) Yes

(2) Yes, just slice a tensor using the loop index (see the sketch after this list)

(3) No need to set back_prop=False in ordinary use cases

(4) The usual things you do with ML models (toy datasets, test parts separately, etc)
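
To make (2) concrete, here is a minimal sketch (dummy names and shapes, not the asker's actual tensors) that slices a tensor with the loop index and accumulates the slices in a TensorArray:

import tensorflow as tf

data = tf.random_normal([8, 5, 16])    # [batch, steps, features], dummy input
n_steps = tf.shape(data)[1]

def cond(i, ta):
    return tf.less(i, n_steps)

def body(i, ta):
    curr = data[:, i, :]               # the slice depends on the loop index
    ta = ta.write(i, curr)             # TensorArray.write returns a new TensorArray
    return i + 1, ta

ta0 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
_, ta_final = tf.while_loop(cond, body, [tf.constant(0), ta0])
stacked = ta_final.stack()             # [steps, batch, features]

Gradients flow back through the write/stack ops as long as back_prop=True (the default), which is what (1) and (3) refer to.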

Re update 2, try using eager execution or tf.contrib.autograph; both should let you write the while loop in plain Python.
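
For example, with eager execution enabled (TF 1.x style; the tensors below are placeholders for illustration), the loop is ordinary Python, so per-step slicing and print()-debugging work directly:

import tensorflow as tf
tf.enable_eager_execution()

data = tf.random_normal([8, 5, 16])     # [batch, steps, features], dummy input
outputs = []
for i in range(int(data.shape[1])):     # a plain Python loop over the step axis
    curr = data[:, i, :]                # ordinary indexing, evaluated eagerly
    outputs.append(tf.reduce_mean(curr, axis=-1))
result = tf.stack(outputs)              # [steps, batch]
print(result.shape)                     # debugging with print() just works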
