I have an input of shape (BATCH_SIZE, A, B, FEATURE_LENGTH). Now I want to select k (out of B) rows from each of the A blocks of each input sample. The k row indices chosen are different for each of the A blocks. For example:
inp = [[[[ 5, 38, 40, 13, 28],
         [12,  6, 36, 20, 23],
         [44, 35, 23, 46,  3]],
        [[22, 32, 36, 20, 42],
         [ 0, 19, 41, 36, 17],
         [ 9, 35, 44,  7, 19]],
        [[27, 10, 22, 10, 48],
         [16, 42, 27,  7, 38],
         [35, 32, 15, 39, 28]]]]
# shape (1, 3, 3, 5) = (BATCH_SIZE, A, B, FEATURE_LENGTH)
Now say k=2, i.e. I want to extract 2 rows from each of the 3 blocks. I want:
rows 0 and 1 from the 1st block,
rows 1 and 2 from the 2nd block,
rows 0 and 2 from the 3rd block.
That means I want my output to look like this:
expected_output = [[[[ 5, 38, 40, 13, 28],
                     [12,  6, 36, 20, 23]],
                    [[ 0, 19, 41, 36, 17],
                     [ 9, 35, 44,  7, 19]],
                    [[27, 10, 22, 10, 48],
                     [35, 32, 15, 39, 28]]]]
# output shape = (1, 3, 2, 5) = (BATCH_SIZE, A, k, FEATURE_LENGTH)
I found that this is possible with tf.gather_nd if I provide the indices as [sample, block, row] triples:
ind = [[[[0, 0, 0], [0, 0, 1]], [[0, 1, 1], [0, 1, 2]], [[0, 2, 0], [0, 2, 2]]]]
output = tf.gather_nd(inp, ind)
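A minimal runnable version of the tf.gather_nd approach from the question, assuming TensorFlow 2 in eager mode. Each innermost triple addresses one full feature row, so the result keeps the index array's leading shape (1, 3, 2) and appends the feature axis (5).

```python
import tensorflow as tf

inp = tf.constant([[[[ 5, 38, 40, 13, 28],
                     [12,  6, 36, 20, 23],
                     [44, 35, 23, 46,  3]],
                    [[22, 32, 36, 20, 42],
                     [ 0, 19, 41, 36, 17],
                     [ 9, 35, 44,  7, 19]],
                    [[27, 10, 22, 10, 48],
                     [16, 42, 27,  7, 38],
                     [35, 32, 15, 39, 28]]]])  # shape (1, 3, 3, 5)

# Every innermost triple is [sample, block, row]; the feature axis is taken
# whole, so the result has shape ind.shape[:-1] + (5,) = (1, 3, 2, 5)
ind = tf.constant([[[[0, 0, 0], [0, 0, 1]],
                    [[0, 1, 1], [0, 1, 2]],
                    [[0, 2, 0], [0, 2, 2]]]])

output = tf.gather_nd(inp, ind)
print(output.shape)  # (1, 3, 2, 5)
```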
But if my input has shape (1, 16, 16, 128) and k=4, writing out such a long index array by hand gets tedious. Is there a simpler way to do this in TensorFlow 2? Thank you!
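If you want to stay with tf.gather_nd, the [sample, block, row] triples do not have to be typed by hand: they can be generated from a compact (BATCH_SIZE, A, k) array of row indices. The sketch below uses a hypothetical helper, build_gather_nd_indices (not a TensorFlow API), which pairs each row index with its sample and block coordinates via tf.meshgrid. The random row choice for the larger case is also just an assumption for illustration.

```python
import tensorflow as tf

def build_gather_nd_indices(rows):
    """Hypothetical helper: expand row indices of shape (BATCH_SIZE, A, k)
    into [sample, block, row] triples of shape (BATCH_SIZE, A, k, 3)."""
    rows = tf.convert_to_tensor(rows, dtype=tf.int32)
    shape = tf.shape(rows)
    # b_idx[i, j, m] == i and a_idx[i, j, m] == j, both shaped like rows
    b_idx, a_idx, _ = tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]),
                                  tf.range(shape[2]), indexing="ij")
    return tf.stack([b_idx, a_idx, rows], axis=-1)

# Reproduces the hand-written index array for the small example
ind = build_gather_nd_indices([[[0, 1], [1, 2], [0, 2]]])
print(ind.numpy().tolist())
# [[[[0, 0, 0], [0, 0, 1]], [[0, 1, 1], [0, 1, 2]], [[0, 2, 0], [0, 2, 2]]]]

# Larger case: input (1, 16, 16, 128), k = 4
x = tf.random.normal((1, 16, 16, 128))
# Assumed row choice: 4 distinct random rows per block, shape (1, 16, 4)
big_rows = tf.argsort(tf.random.uniform((1, 16, 16)), axis=-1)[..., :4]
big_out = tf.gather_nd(x, build_gather_nd_indices(big_rows))
print(big_out.shape)  # (1, 16, 4, 128)
```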
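Use tf.gather() with the batch_dims argument. With batch_dims=2, the first two axes (sample and block) are treated as batch axes, so every block is gathered along its rows (axis 2) using its own list of k row indices: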
inds = tf.constant([[[0, 1], [1, 2], [0, 2]]])
output = tf.gather(inp, inds, batch_dims=2)
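A self-contained run of the answer on the example data, assuming TensorFlow 2 in eager mode. Because batch_dims=2 and axis defaults to batch_dims, the output shape is inp.shape[:2] + inds.shape[2:] + inp.shape[3:] = (1, 3, 2, 5). As a cross-check, tf.gather_nd also accepts batch_dims, so appending a length-1 axis to the same row indices selects the same rows without building full triples.

```python
import tensorflow as tf

inp = tf.constant([[[[ 5, 38, 40, 13, 28],
                     [12,  6, 36, 20, 23],
                     [44, 35, 23, 46,  3]],
                    [[22, 32, 36, 20, 42],
                     [ 0, 19, 41, 36, 17],
                     [ 9, 35, 44,  7, 19]],
                    [[27, 10, 22, 10, 48],
                     [16, 42, 27,  7, 38],
                     [35, 32, 15, 39, 28]]]])  # shape (1, 3, 3, 5)

# Row indices per (sample, block): shape (BATCH_SIZE, A, k) = (1, 3, 2)
inds = tf.constant([[[0, 1], [1, 2], [0, 2]]])

# Axes 0 and 1 are batch axes; rows are gathered along axis 2
output = tf.gather(inp, inds, batch_dims=2)
print(output.shape)  # (1, 3, 2, 5)

# Equivalent tf.gather_nd call: indices of shape (1, 3, 2, 1) address axis 2
output_nd = tf.gather_nd(inp, inds[..., tf.newaxis], batch_dims=2)
```

For the (1, 16, 16, 128) input with k=4, inds simply has shape (1, 16, 4) and the same tf.gather call returns a (1, 16, 4, 128) tensor; nothing else changes.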