
Inverse of tf.gather()

Suppose I have a tensor a of shape [B, D] and a list I of indices with shape [B]. I want to expand a into a tensor of shape [M, D], with M > B, using the indices in I. The indices lie in the range [0, M), i.e. every index is a valid row of the larger tensor. Concretely, I maps each row of a to a row of another tensor that is larger along dimension 0. This is the opposite of what tf.gather() does. Could someone suggest a solution? Thanks.
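To make the requested mapping concrete, it can be phrased as a check with tf.gather: the target tensor b of shape [M, D] should satisfy tf.gather(b, I) == a. The sketch below builds such a b by hand for a tiny, made-up case (B = 2, D = 3, M = 4). The names and values are illustrative only, and TensorFlow 2.x with eager execution is assumed.

```python
import tensorflow as tf

# Made-up example: B = 2, D = 3, M = 4.
a = tf.constant([[1, 2, 3],
                 [4, 5, 6]])   # shape [B, D]
I = tf.constant([3, 0])        # shape [B]: row i of a should land at row I[i]

# The tensor we want, written out by hand for this example:
# row 3 holds a[0], row 0 holds a[1], and every other row is zero.
b = tf.constant([[4, 5, 6],
                 [0, 0, 0],
                 [0, 0, 0],
                 [1, 2, 3]])   # shape [M, D]

# "Inverse of tf.gather" means that gathering rows I of b gives back a.
recovered = tf.gather(b, I)
print(recovered.numpy())
```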

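tf.scatter_nd is the reverse of tf.gather_nd: tf.scatter_nd writes slices into an all-zero tensor at the given indices, and tf.gather_nd reads slices back out at those indices. Let's see that with a round-trip example in which: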

  • First, we use tf.scatter_nd to create a tensor of shape (5,4,1,2,3) in which all elements are zero except the vectors at indices [1,2,0,0] and [3,0,0,1], which are [16, 12, 11] and [18, 40, 37] respectively.
  • Second, we do the reverse by applying tf.gather_nd to the generated tensor to get back the two original vectors, i.e., [16, 12, 11] and [18, 40, 37].
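```python
import tensorflow as tf

updates = [[16, 12, 11], [18, 40, 37]]  # two length-3 vectors to place
indices = [[1, 2, 0, 0], [3, 0, 0, 1]]  # each index addresses the first 4 dimensions
shape = (5, 4, 1, 2, 3)                  # the last dimension (3) matches each update vector

# First step: scatter the updates into an all-zero tensor of the given shape
scat_tensor = tf.scatter_nd(indices=indices, updates=updates, shape=shape)
print(f"verification: expected {updates[0]}, got {scat_tensor[1, 2, 0, 0].numpy()}")

# Now the reverse step: gather the same slices back out
reconstructed_updates = tf.gather_nd(scat_tensor, indices)
print(f"verification: expected {updates}, got {reconstructed_updates.numpy()}")
```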
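Applied to the shapes in the question, a minimal sketch (assuming TensorFlow 2.x with eager execution; the values of a, I and M are made up) passes the 1-D index list I to tf.scatter_nd reshaped to [B, 1]. With that shape each index selects a whole row of the [M, D] output, so the updates are simply a, of shape [B, D]. The last lines show one possible variant, tf.tensor_scatter_nd_update on a pre-filled tensor, for when the unused rows should not be zero; the fill value -1.0 is an arbitrary choice.

```python
import tensorflow as tf

# Made-up example data: B = 3 rows of width D = 2, expanded to M = 5 rows.
a = tf.constant([[1., 2.], [3., 4.], [5., 6.]])  # shape [B, D]
I = tf.constant([4, 0, 2])                        # shape [B], values in [0, M)
M = 5

# Reshape I to [B, 1] so that each index addresses a full row of the output.
b = tf.scatter_nd(indices=tf.expand_dims(I, axis=1), updates=a, shape=[M, a.shape[1]])
print(b.numpy())
# Rows 4, 0 and 2 of b hold a[0], a[1] and a[2]; rows 1 and 3 are zero.

# Round trip: tf.gather on the result recovers a.
print(tf.gather(b, I).numpy())

# Variant with a non-zero background: write the rows of a into a tensor filled with -1.0.
c = tf.tensor_scatter_nd_update(tf.fill([M, a.shape[1]], -1.0),
                                tf.expand_dims(I, axis=1), a)
print(c.numpy())
```

Note that tf.scatter_nd sums updates that share an index, so if I contains duplicates the corresponding rows of a are added together rather than one overwriting the other.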
