I've always been able to use tf.tensor_scatter_nd_update to write into tensors without any problems, but I can't figure out why it's not working with some specific tensors.
As a simple example, say I want to set certain values in input=[[0 0 0]] to update=[[1 2 3]], based on a boolean mask mask=[[1 0 1]]. I would simply do:
input=tf.tensor_scatter_nd_update(input,tf.where(mask),update)
expecting the result of the operation to be input=[[1 0 3]].
Instead I'm getting
ValueError: Dimensions [2,2) of input[shape=[1,3]] = [] must match dimensions [1,2) of updates[shape=[1,3]] = [3]: Shapes must be equal rank, but are 0 and 1 for ... with input shapes: [1,3], [?,2], [1,3].
I really can't work out what's wrong; I've always been able to use the function without issue even in much more complex cases.
I figured it out.
Part of the problem is indeed that tf.where() returns a 2-D tensor, but this came into play because I was also using it to generate the updates vector:
input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.where(something_else))
The solution is to remove the extra dimension from the updates with tf.squeeze():
input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.squeeze(tf.where(something_else)))
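For the simple example in the question, the error seems to come from the shape rule for updates: with N index rows that each address a single element, updates needs shape [N], not the shape of the input. The following is a minimal sketch of one way to satisfy that rule, assuming TensorFlow 2.x in eager mode. The names x, values, result and alt are made up for illustration, and the mask is written as a bool tensor here rather than as 1/0 integers.

```python
import tensorflow as tf

# Tensors from the question; "x" stands in for "input" (illustrative name).
x = tf.constant([[0, 0, 0]])
update = tf.constant([[1, 2, 3]])
mask = tf.constant([[True, False, True]])  # assumed bool dtype

# tf.where(mask) returns one row of coordinates per True entry: shape [2, 2].
indices = tf.where(mask)

# Each index row addresses a single element, so updates must have shape [2].
# tf.boolean_mask picks the matching values out of `update`: [1, 3].
values = tf.boolean_mask(update, mask)

result = tf.tensor_scatter_nd_update(x, indices, values)
print(result.numpy())

# Alternative when update has the same shape as x: element-wise select.
alt = tf.where(mask, update, x)
print(alt.numpy())
```

When the updates come from tf.where() on a 1-D tensor, the result has shape [N, 1]. Calling tf.squeeze(..., axis=-1) removes only that trailing dimension, so the tensor does not collapse to a scalar when N happens to be 1.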