Suppose I have a tensor `a` of shape `[B, D]` and a list `I` of indices with shape `[B]`. Now I want to extend the tensor to shape `[M, D]`, with `M > B`, using the indices in the list. Note that the indices lie in the range `[0, M)`. Concretely, `I` maps the rows of `a` to rows of another tensor that is larger along dimension 0. This functionality is the opposite of `tf.gather()`. Could someone suggest a solution? Thanks
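To make the requested behaviour concrete, here is a minimal NumPy reference sketch of the operation being asked for. It assumes that output rows not named in `I` should be zero and that the entries of `I` are unique; the function name `expand_rows` and the sample values are hypothetical.

```python
import numpy as np


def expand_rows(a, I, M):
    """Hypothetical reference: place row i of `a` at row I[i] of an [M, D] zero array."""
    a = np.asarray(a)
    I = np.asarray(I)
    # Assumption: rows of the output that no index points to are zero.
    out = np.zeros((M, a.shape[1]), dtype=a.dtype)
    # Row a[i] lands at out[I[i]]; assumes I contains no duplicates.
    out[I] = a
    return out


a = np.array([[1, 2], [3, 4], [5, 6]])  # B = 3, D = 2
I = [4, 0, 2]                            # target rows, each in [0, M)
out = expand_rows(a, I, M=5)
print(out)

# Gathering the target rows back recovers `a`, i.e. this undoes a gather.
assert (out[I] == a).all()
```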
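`tf.scatter_nd` is the reverse of `tf.gather_nd`. Let's see that with a round-trip example where we:

1. Use `tf.scatter_nd` to build a tensor of shape `(5, 4, 1, 2, 3)` in which all elements are zero except the slices at indices `[1, 2, 0, 0]` and `[3, 0, 0, 1]`, which hold `[16, 12, 11]` and `[18, 40, 37]` respectively.
2. Use `tf.gather_nd` on the generated tensor to get the two original vectors back, i.e., `[16, 12, 11]` and `[18, 40, 37]`.

Because each index has 4 components while the tensor has 5 dimensions, every index addresses a whole slice of shape `[3]`, which is why each update is a 3-element vector.

```python
import tensorflow as tf

updates = [[16, 12, 11], [18, 40, 37]]
indices = [[1, 2, 0, 0], [3, 0, 0, 1]]
shape = (5, 4, 1, 2, 3)

# First step - scatter: each 4-component index selects a slice of shape [3]
scat_tensor = tf.scatter_nd(indices=indices, updates=updates, shape=shape)
print(f"verification: expected {updates[0]}, got {scat_tensor[1, 2, 0, 0]}")

# Now the reverse step - gather the same slices back out
reconstructed_updates = tf.gather_nd(scat_tensor, indices)
print(f"verification: expected {updates}, got {reconstructed_updates}")
```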
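Applied to the original question, the same idea becomes a single `tf.scatter_nd` call once `I` is reshaped to `[B, 1]`, so that each index addresses a whole row of length `D`. This is a minimal sketch with made-up values for `a`, `I` and `M`; rows not listed in `I` come out as zero, and `tf.scatter_nd` sums the updates if `I` contains duplicate indices.

```python
import tensorflow as tf

# Illustrative values (assumptions): B = 3, D = 2, M = 5
a = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # shape [B, D]
I = tf.constant([4, 0, 2])                             # shape [B], values in [0, M)
M = 5
D = a.shape[1]  # static size of dimension 1 (known in eager mode)

# An index of depth 1 addresses a full row, so the indices need shape [B, 1].
indices = tf.expand_dims(I, axis=1)

# Output has shape [M, D]; rows not named in I stay zero.
out = tf.scatter_nd(indices, a, shape=[M, D])
print(out.numpy())

# tf.gather is the inverse: pulling rows I back out recovers a.
print(tf.reduce_all(tf.equal(tf.gather(out, I), a)).numpy())
```

If the rows should instead be written into an existing `[M, D]` tensor rather than into zeros, `tf.tensor_scatter_nd_update(base, indices, a)` takes the same `indices` and overwrites those rows of `base`.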