
Define a tensor of constants for tf.while_loop

I want to maintain a list of constant values inside a tf.while_loop that supports the following operations:

  1. Reading and writing (possibly multiple times) the value at a given index
  2. Running tf.cond on it by comparing the value at an index against some constant

TensorArray does not work here because an element cannot be rewritten once it has been written. What other options do I have?

You can use an ordinary tensor as a loop variable and update it with tf.tensor_scatter_nd_update:

# Requires TensorFlow 1.x (in Google Colab, run the magic `%tensorflow_version 1.x` first)
import tensorflow as tf

data = tf.constant([1, 1, 1, 0, 1, 0, 1, 1, 0, 0], dtype=tf.float32)
data_tensor = tf.zeros_like(data)  # the "list" that is rewritten inside the loop
tensor_size = int(data_tensor.shape[0])

init_state = (0, data_tensor)
condition = lambda i, _: i < tensor_size

def custom_body(i, tensor):
  special_index = 3  # index whose value should be overwritten with new_value
  new_value = 8.0
  # tf.tensor_scatter_nd_update returns a new tensor with the given positions
  # replaced, so the same index can be written any number of times.
  # The scalar predicate of tf.where picks which of the two updated tensors to keep.
  tensor = tf.where(tf.equal(i, special_index),
                    tf.tensor_scatter_nd_update(tensor, [[special_index]], [new_value]),
                    tf.tensor_scatter_nd_update(tensor, [[i]], [data[i] * 2]))
  return i + 1, tensor

_, final_result = tf.while_loop(condition, custom_body, init_state)

with tf.Session() as sess:
  final_result_values = sess.run(final_result)

print(final_result_values)
# [2. 2. 2. 8. 2. 0. 2. 2. 0. 0.]
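The loop above selects between updates with tf.where, which builds both updated tensors on every iteration. The question's second requirement, branching with tf.cond on the value stored at an index, can be sketched as follows. This is a minimal sketch that assumes TensorFlow 1.15 or 2.x. It builds an explicit graph and runs it through tf.compat.v1.Session so it should behave the same on both versions. The names run_counter_loop, num_steps, size and limit are hypothetical and chosen only for illustration. The loop cycles over a small buffer, reads the slot for the current step, and uses tf.cond to either increment that slot or reset it to 0. Because of this, the same index gets rewritten several times.

```python
import tensorflow as tf


def run_counter_loop(num_steps=7, size=3, limit=2.0):
    """Hypothetical example: visit the slots of a `size`-element buffer in
    round-robin order for `num_steps` iterations. At each visited slot, read
    its value and use tf.cond to increment it while it is below `limit`,
    otherwise reset it to 0."""
    graph = tf.Graph()
    with graph.as_default():
        buffer = tf.zeros([size], dtype=tf.float32)

        def body(i, buf):
            idx = tf.math.floormod(i, size)      # slot visited at this step
            current = buf[idx]                    # read the value at idx
            new_value = tf.cond(
                current < limit,
                lambda: current + 1.0,            # below the limit: increment
                lambda: tf.constant(0.0))         # limit reached: reset to 0
            # Write back to the same slot; repeated writes are fine because
            # tensor_scatter_nd_update returns a new tensor each time.
            buf = tf.tensor_scatter_nd_update(buf, [[idx]], [new_value])
            return i + 1, buf

        _, result = tf.while_loop(lambda i, _: i < num_steps, body,
                                  (tf.constant(0), buffer))

    with tf.compat.v1.Session(graph=graph) as sess:
        return sess.run(result)


print(run_counter_loop())  # [0. 2. 2.]
```

The defaults run 7 steps over 3 slots with a limit of 2.0. Slot 0 is visited three times and is reset on its third visit, while slots 1 and 2 each reach 2.0. Unlike tf.where, tf.cond only executes the branch that is taken, so it is usually the better fit when the two branches do different work.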
