TypeError: List of Tensors when single Tensor expected due to tensor_scatter_update

Take a look at the following code sample:

def myFun(my_tensor):
    # The following line works
    my_tensor = tf.tensor_scatter_update(my_tensor, tf.constant([[0]]), tf.constant([1]))
    # The following line leads to the error
    p = tf.cond(tf.math.equal(0, 0), lambda: 1, lambda: 1)
    my_tensor = tf.tensor_scatter_update(my_tensor, tf.constant([[p]]), tf.constant([1]))

I have reduced the problem to a simple case. This function (myFun) is called as the body of a tf.while_loop, in case that is relevant. my_tensor is defined as follows:

my_tensor = tf.zeros(5, tf.int32)

How do I define the indices parameter of tf.tensor_scatter_update? I am using TensorFlow 1.15.

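You cannot pass the tensor p to tf.constant: it builds a constant from a Python or NumPy value, not from another tensor, which is why it raises the TypeError. Pass the nested list [[p]] directly as indices instead and let TensorFlow convert it. Maybe try something like this: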

# Colab-only magic that selects TensorFlow 1.x; omit it outside Colab
%tensorflow_version 1.x
import tensorflow as tf

def myFun(my_tensor):

    my_tensor= tf.tensor_scatter_update(my_tensor, tf.constant([[0]]), tf.constant([1]))
    p = tf.cond(tf.math.equal(0, 0), lambda: 1, lambda: 1)
    new_tensor= tf.tensor_scatter_update(my_tensor, [[p]], tf.constant([1]))

    with tf.Session() as sess:
      p_value = p.eval()
      tensor_values = my_tensor.eval()
      new_tensor_values = new_tensor.eval()

    print('p -->', p_value)
    print('my_tensor -->', tensor_values)
    print('new_tensor -->', new_tensor_values)

my_tensor = tf.zeros(5, tf.int32)
myFun(my_tensor)

Output:

p --> 1
my_tensor --> [1 0 0 0 0]
new_tensor --> [1 1 0 0 0]
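
Since the question mentions that myFun runs as the body of a tf.while_loop, here is a minimal sketch of the same idea inside a loop. It builds the indices with tf.reshape rather than a nested list, which gives the required [num_updates, index_depth] shape without needing a variable or an initializer. The helper name scatter_at, the loop bound of 3, and the use of the tf.compat.v1 aliases are assumptions made for illustration, not part of the original code.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: use the v1 aliases so the names match the question


def scatter_at(my_tensor, p):
    """Hypothetical helper: write 1 into my_tensor at the tensor-valued position p."""
    # indices must have shape [num_updates, index_depth]; here that is [1, 1].
    indices = tf.reshape(tf.cast(p, tf.int32), [1, 1])
    return tf1.tensor_scatter_update(my_tensor, indices, tf.constant([1]))


graph = tf.Graph()
with graph.as_default():
    my_tensor = tf.zeros(5, tf.int32)

    def body(i, t):
        # i is a loop tensor, so it cannot be handed to tf.constant.
        return i + 1, scatter_at(t, i)

    # Illustrative loop: update positions 0, 1 and 2 in turn.
    _, result = tf1.while_loop(lambda i, t: i < 3, body, [tf.constant(0), my_tensor])

    with tf1.Session(graph=graph) as sess:
        result_value = sess.run(result)

print(result_value)
```

The reshape keeps the index a tensor throughout, so the same function works whether p comes from tf.cond, from a loop counter, or from a placeholder.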

You can also wrap p in a tf.Variable, which must be initialized in the session before it is read:

def myFun(my_tensor):

    my_tensor= tf.tensor_scatter_update(my_tensor, tf.constant([[0]]), tf.constant([1]))
    p = tf.cond(tf.math.equal(0, 0), lambda: 1, lambda: 1)

    indices = tf.Variable([[p]])       
    new_tensor= tf.tensor_scatter_update(my_tensor, indices, tf.constant([1]))

    with tf.Session() as sess:
      sess.run(indices.initializer)
      p_value = p.eval()
      tensor_values = my_tensor.eval()
      new_tensor_values = new_tensor.eval()
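
For intuition about what the indices argument means, the following plain-Python sketch mimics the scatter update on a rank-1 tensor. It is not TensorFlow's implementation, and the function name scatter_update_1d is hypothetical. It only illustrates that indices holds one row per update and that each row holds one position per dimension of the tensor being updated.

```python
def scatter_update_1d(values, indices, updates):
    """Return a copy of values with updates[k] written at indices[k][0]."""
    if len(indices) != len(updates):
        raise ValueError("need one index row per update")
    result = list(values)  # copy, so the input list is left untouched
    for row, update in zip(indices, updates):
        (position,) = row  # index_depth is 1 for a rank-1 tensor
        result[position] = update
    return result


my_tensor = [0, 0, 0, 0, 0]
step1 = scatter_update_1d(my_tensor, [[0]], [1])  # like indices=[[0]]
step2 = scatter_update_1d(step1, [[1]], [1])      # like indices=[[p]] with p == 1
print(step1)
print(step2)
```

The two calls mirror the two updates in myFun: the first sets position 0, and the second sets position p, which is 1 in the example above.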
