I'm using a while_loop in TensorFlow to iterate over a tensor and extract specific slices along a given dimension. For each step, I need to use a decoder RNN to generate a sequence of output symbols. I'm using the code provided in tf.contrib.seq2seq , in particular tf.contrib.seq2seq.dynamic_decode . The code looks similar to the following:
def decoder_condition(i, data, source_seq_len, ta_outputs):
    return tf.less(i, max_loop_len)

def decoder_body(i, data, source_seq_len, ta_outputs):
    curr_data = data[:, i, :]
    # memory_sequence_length expects a 1-D [batch] tensor
    curr_source_seq_len = source_seq_len[:, i]
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        2 * self.opt["encoder_rnn_h_size"],
        curr_data,
        memory_sequence_length=curr_source_seq_len
    )
    cell = tf.contrib.rnn.GRUCell(num_units)
    cell = tf.contrib.seq2seq.AttentionWrapper(cell, attention_mechanism)
    # ... other code that initialises all the variables required
    # for the RNN decoder (including `decoder` itself)
    # dynamic_decode returns (final_outputs, final_state, final_sequence_lengths)
    outputs, final_state, final_seq_len = tf.contrib.seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=self.opt["max_sys_seq_len"],
        swap_memory=True
    )
    with tf.control_dependencies([outputs.rnn_output]):
        ta_outputs = ta_outputs.write(i, outputs.rnn_output)
    # The body must return every loop variable, in the same order as loop_vars
    return i + 1, data, source_seq_len, ta_outputs

loop_index = tf.constant(0)
ta_outputs = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
outputs = tf.while_loop(
    decoder_condition,
    decoder_body,
    loop_vars=[
        loop_index,
        data,
        data_source_len,
        ta_outputs
    ],
    swap_memory=True,
    back_prop=True,
    parallel_iterations=1
)
So, as you can see, I create different objects which depend specifically on the input at the current step i . I'm using tf.AUTO_REUSE
in my current variable scope so that the variables are reused even though I'm creating different objects. Unfortunately, my decoder doesn't seem to be training properly: it keeps generating incorrect values. I've already checked the input data fed to the decoder RNN and everything is correct. I suspect there is something I'm not doing properly in terms of how TensorFlow manages the TensorArray and while_loop.
So my main questions are:
Thanks!
UPDATE: Not sure why, but it seems there is an open issue about this, related to the possibility of invoking custom operations in a while loop, as explained here: https://github.com/tensorflow/tensorflow/issues/13616 . Unfortunately, I don't know enough about TensorFlow's internals to judge whether it's fully related.
UPDATE 2: I solved using PyTorch :)
(1) Yes
(2) Yes, just slice a tensor using the loop index
(3) No, there's no need to set back_prop=False in ordinary use cases
(4) The usual things you do with ML models (toy datasets, testing parts separately, etc.)
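One pitfall worth adding here: the body function passed to tf.while_loop must return exactly the same loop variables it receives, in the same order, even the ones it never modifies (in your snippet the body drops source_seq_len from its return value). A plain-Python analogue of that contract, with a toy doubling step standing in for the decoder (no TensorFlow needed to run it):

```python
# Plain-Python analogue of the tf.while_loop contract: the body must
# return the same number and structure of loop variables it receives,
# or graph construction fails with a structure mismatch error.
def while_loop(cond, body, loop_vars):
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars

# Loop vars: (i, data, source_seq_len, outputs) -- all four must be
# threaded through every iteration, including source_seq_len, which
# the body reads but never changes.
def cond(i, data, source_seq_len, outputs):
    return i < len(data)

def body(i, data, source_seq_len, outputs):
    outputs = outputs + [data[i] * 2]  # toy stand-in for the decoder step
    return i + 1, data, source_seq_len, outputs  # all four, in order

_, _, _, result = while_loop(cond, body, (0, [1, 2, 3], [3, 3, 3], []))
# result == [2, 4, 6]
```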
Re update 2: try using eager execution or tf.contrib.autograph; both should let you write the while loop in plain Python.
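For illustration, here is roughly what the decode loop reduces to in plain Python once eager execution (or autograph) removes the graph-mode machinery; decode_step is a hypothetical stand-in for the real dynamic_decode call, and the list replaces the TensorArray:

```python
# Sketch only: decode_step stands in for building the attention decoder
# and calling dynamic_decode on the i-th slice of the input.
def decode_step(slice_):
    return [x + 1 for x in slice_]  # placeholder for real decoding

def decode_all(data):               # data: list of per-step slices
    outputs = []                    # replaces the TensorArray
    for i in range(len(data)):      # replaces tf.while_loop
        outputs.append(decode_step(data[i]))
    return outputs

# decode_all([[1, 2], [3, 4]]) -> [[2, 3], [4, 5]]
```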