I want to maintain a list of constants inside a tf.while_loop that supports running tf.cond on it, by checking its value at an index against some constant. A TensorArray would not work here since it does not support rewrites. What other options do I have?
You could just define a normal Tensor and update it with tf.tensor_scatter_nd_update:
```python
# Colab only: select TensorFlow 1.x, which provides tf.Session
%tensorflow_version 1.x
import tensorflow as tf

data = tf.constant([1, 1, 1, 0, 1, 0, 1, 1, 0, 0], dtype=tf.float32)
data_tensor = tf.zeros_like(data)        # the "list" the loop keeps rewriting
tensor_size = int(data_tensor.shape[0])  # static length of the list (10)

init_state = (0, data_tensor)
condition = lambda i, _: i < tensor_size

def custom_body(i, tensor):
    special_index = 3  # index whose value should be replaced
    new_value = 8.0    # float, to match the tensor's dtype
    # tf.where with a scalar condition selects one of the two updated tensors
    tensor = tf.where(tf.equal(i, special_index),
                      tf.tensor_scatter_nd_update(tensor, [[special_index]], [new_value]),
                      tf.tensor_scatter_nd_update(tensor, [[i]], [data[i] * 2]))
    return i + 1, tensor

_, final_result = tf.while_loop(condition, custom_body, init_state)

with tf.Session() as sess:
    final_result_values = final_result.eval()
    print(final_result_values)
```

Output:

```
[2. 2. 2. 8. 2. 0. 2. 2. 0. 0.]
```
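For the tf.cond part of the question (branching on the value currently stored at an index of the loop-carried tensor), here is a minimal sketch assuming TensorFlow 2.x instead of the 1.x session API above. The function name toggle_demo and the toggling rule are hypothetical and only for illustration: each step reads state[idx], compares it with the constant 0.0 inside tf.cond, and writes the chosen value back with tf.tensor_scatter_nd_update.

```python
import tensorflow as tf  # assumes TensorFlow 2.x

@tf.function
def toggle_demo(steps):
    # Hypothetical example: a loop-carried "list" of 4 float slots, all zero.
    state = tf.zeros([4], dtype=tf.float32)

    def body(i, state):
        idx = i % 4
        current = state[idx]  # read the value stored at an index
        # Branch on that value vs. a constant: 0 -> write 1.0, otherwise write 0.0
        new_val = tf.cond(tf.equal(current, 0.0),
                          lambda: tf.constant(1.0),
                          lambda: tf.constant(0.0))
        # Rewrite the slot; this returns a new tensor rather than mutating state
        state = tf.tensor_scatter_nd_update(state, [[idx]], [new_val])
        return i + 1, state

    _, state = tf.while_loop(lambda i, _: i < steps, body,
                             (tf.constant(0), state))
    return state

print(toggle_demo(tf.constant(6)).numpy())  # → [0. 0. 1. 1.]
```

Because tf.tensor_scatter_nd_update returns an updated copy instead of modifying its input, the body has to return the new tensor so that the next iteration (and the next tf.cond) sees the rewritten value.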