
tensor_scatter_nd_update ValueError: Shapes must be equal rank, but are 0 and 1

I've always been able to use tf.tensor_scatter_nd_update without any problems to write into tensors, but I can't figure out why it's not working with some specific tensors.

As a simple example, say I want to set certain values in input=[[0 0 0]] to the corresponding values in update=[[1 2 3]], based on a boolean mask mask=[[1 0 1]]. I would simply do:

input=tf.tensor_scatter_nd_update(input,tf.where(mask),update)

expecting the result of the operation to be input=[[1 0 3]].
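For reference, here is a minimal sketch of a version of the example above that runs, assuming TensorFlow 2.x in eager mode. It uses a bool mask instead of 0/1 integers and renames `input` to `tensor` (to avoid shadowing Python's built-in); both are illustrative assumptions. The change that matters is that the updates passed to tf.tensor_scatter_nd_update hold one value per index row returned by tf.where, picked out here with tf.boolean_mask.

```python
import tensorflow as tf

tensor = tf.constant([[0, 0, 0]])
update = tf.constant([[1, 2, 3]])
mask = tf.constant([[True, False, True]])  # bool instead of [[1 0 1]] (assumption)

# tf.where(mask) gives one row of coordinates per True element:
# [[0, 0], [0, 2]], shape [2, 2], dtype int64
indices = tf.where(mask)

# One update value is needed per index row, so updates must have shape [2];
# boolean_mask keeps only the elements of `update` where `mask` is True
values = tf.boolean_mask(update, mask)  # [1, 3]

result = tf.tensor_scatter_nd_update(tensor, indices, values)
print(result.numpy())  # [[1 0 3]]
```

For purely element-wise masked assignment like this, the three-argument form tf.where(mask, update, tensor) should give the same [[1 0 3]] without building indices at all.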

Instead, I get:

ValueError: Dimensions [2,2) of input[shape=[1,3]] = [] must match dimensions [1,2) of updates[shape=[1,3]] = [3]: Shapes must be equal rank, but are 0 and 1 for ... with input shapes: [1,3], [?,2], [1,3].
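The error message is easier to read with the shape rule from the tf.tensor_scatter_nd_update documentation in mind: updates must have shape indices.shape[:-1] + tensor.shape[indices.shape[-1]:]. The small pure-Python helper below (a hypothetical name, not part of TensorFlow) just computes that shape. With input [1, 3] and indices [?, 2], the index depth is 2, so updates are expected to have shape [?], i.e. rank 1. Passing updates of shape [1, 3] leaves a trailing [3] where the rule expects [], which is the "rank 0 vs 1" mismatch in the message.

```python
def expected_updates_shape(tensor_shape, indices_shape):
    """Hypothetical helper: the `updates` shape tf.tensor_scatter_nd_update expects,
    i.e. indices.shape[:-1] + tensor.shape[indices.shape[-1]:]."""
    index_depth = indices_shape[-1]  # how many leading dims each index row addresses
    return list(indices_shape[:-1]) + list(tensor_shape[index_depth:])


# Shapes from the question: input [1, 3], indices [?, 2] (here 2 True elements)
print(expected_updates_shape([1, 3], [2, 2]))  # [2]
```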

I really can't work out what's wrong; I've always been able to use the function without issue even in much more complex cases.

I figured it out.

Part of the problem is indeed that tf.where() returns a 2-D tensor (one row per true element, one column per dimension of its input), but this came into play because I was also using it to generate the updates vector:
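To see that extra dimension concretely, here is a quick check assuming TensorFlow 2.x, with a hypothetical 1-D tensor standing in for `something_else`. With a single argument, tf.where returns int64 coordinates of shape [num_true, rank], so even a 1-D input yields a 2-D result.

```python
import tensorflow as tf

# Hypothetical stand-in for `something_else`
something_else = tf.constant([True, False, True, True])

idx = tf.where(something_else)
print(idx.shape.as_list())  # [3, 1]: one row per True element, one column per input dim
print(idx.numpy().tolist())  # [[0], [2], [3]]
```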

input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.where(something_else))

The solution is to remove the extra dimension with tf.squeeze():

input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.squeeze(tf.where(something_else)))
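Putting the fix together, here is a sketch assuming TensorFlow 2.x and 1-D tensors; the answer does not say what shapes `mask` and `something_else` have, so the values below are hypothetical. The target is created as int64 because tf.tensor_scatter_nd_update requires updates to share the tensor's dtype, and tf.where returns int64.

```python
import tensorflow as tf

# Hypothetical 1-D stand-ins for the answer's tensors
target = tf.zeros([5], dtype=tf.int64)  # int64 to match tf.where's output dtype
mask = tf.constant([False, True, False, True, False])
something_else = tf.constant([False, False, True, False, True])

indices = tf.where(mask)            # [[1], [3]], shape [2, 1]
raw = tf.where(something_else)      # [[2], [4]], shape [2, 1]: one dimension too many
updates = tf.squeeze(raw, axis=-1)  # [2, 4], shape [2] == indices.shape[:-1]

result = tf.tensor_scatter_nd_update(target, indices, updates)
print(result.numpy())  # [0 2 0 4 0]
```

Passing axis=-1 is a choice made in this sketch: a bare tf.squeeze also drops the leading dimension when exactly one element is True (shape [1, 1] becomes a scalar), which would break the shape rule again because indices.shape[:-1] is still [1].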
