Using tf.gather or tf.gather_nd

I have an input of shape (BATCH_SIZE, A, B, FEATURE_LENGTH). For each input sample, I want to select k (out of B) rows from each of the A blocks. The selected row indices differ from block to block. For example:

inp = ([[[[ 5, 38, 40, 13, 28],
         [12,  6, 36, 20, 23],
         [44, 35, 23, 46,  3]],

        [[22, 32, 36, 20, 42],
         [ 0, 19, 41, 36, 17],
         [ 9, 35, 44,  7, 19]],

        [[27, 10, 22, 10, 48],
         [16, 42, 27,  7, 38],
         [35, 32, 15, 39, 28]]]])
# shape (1, 3, 3, 5) = (BATCH_SIZE, A, B, FEATURE_LENGTH)

Now say k=2, i.e. I want to extract 2 rows from each of the 3 blocks. Specifically, I want:

row 0 and 1 from 1st block
row 1 and 2 from 2nd block
row 0 and 2 from 3rd block

That means I want my output to look like

([[[[ 5, 38, 40, 13, 28],
    [12,  6, 36, 20, 23]],

   [[ 0, 19, 41, 36, 17],
    [ 9, 35, 44,  7, 19]],

   [[27, 10, 22, 10, 48],
    [35, 32, 15, 39, 28]]]])
# output shape = (1, 3, 2, 5)
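For reference, the selection rule can be stated as out[n][a][j] = inp[n][a][rows[n][a][j]], where rows holds the k chosen row indices for each block. A minimal plain-Python sketch of that rule (no TensorFlow; the helper name select_rows and the rows list are illustrative, not from the question) reproduces the desired output above:

```python
def select_rows(inp, rows):
    """Loop-based reference: out[n][a][j] = inp[n][a][rows[n][a][j]].

    inp:  nested list of shape (BATCH_SIZE, A, B, FEATURE_LENGTH)
    rows: nested list of shape (BATCH_SIZE, A, k), row indices in [0, B)
    """
    return [
        [
            [block[r] for r in block_rows]  # keep only the chosen rows of this block
            for block, block_rows in zip(sample, sample_rows)
        ]
        for sample, sample_rows in zip(inp, rows)
    ]


inp = [[[[ 5, 38, 40, 13, 28],
         [12,  6, 36, 20, 23],
         [44, 35, 23, 46,  3]],
        [[22, 32, 36, 20, 42],
         [ 0, 19, 41, 36, 17],
         [ 9, 35, 44,  7, 19]],
        [[27, 10, 22, 10, 48],
         [16, 42, 27,  7, 38],
         [35, 32, 15, 39, 28]]]]

# Rows 0,1 from block 0; rows 1,2 from block 1; rows 0,2 from block 2.
rows = [[[0, 1], [1, 2], [0, 2]]]

out = select_rows(inp, rows)
```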

I found that this is possible with tf.gather_nd if I provide the indices as:

ind = array([[[[0, 0, 0], [0, 0, 1]], [[0, 1, 1], [0, 1, 2]], [[0, 2, 0], [0, 2, 2]]]])

But if the input has shape (1,16,16,128) and k=4, writing out this long index array by hand gets tedious. Is there a simpler way to do this in TensorFlow 2? Thank you!
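If you want to stay with tf.gather_nd, the index array does not have to be typed by hand: the batch and block coordinates can be generated with tf.meshgrid and stacked with a (BATCH_SIZE, A, k) tensor of row indices. This is a sketch under the assumption that such a row-index tensor is available; make_gather_nd_indices is a hypothetical helper name, and the random indices in the larger case only illustrate the shapes (they may repeat within a block).

```python
import tensorflow as tf


def make_gather_nd_indices(rows):
    """Turn per-block row indices into full (batch, block, row) triples.

    rows: integer tensor of shape (BATCH_SIZE, A, k), values in [0, B).
    Returns an int32 tensor of shape (BATCH_SIZE, A, k, 3) for tf.gather_nd.
    """
    rows = tf.cast(rows, tf.int32)
    shape = tf.shape(rows)
    # Batch and block coordinates for every (n, a, j) slot,
    # each of shape (BATCH_SIZE, A, k).
    n_idx, a_idx, _ = tf.meshgrid(
        tf.range(shape[0]), tf.range(shape[1]), tf.range(shape[2]), indexing="ij"
    )
    # Last axis holds (batch index, block index, row index).
    return tf.stack([n_idx, a_idx, rows], axis=-1)


# Rows to keep from each block of the (1, 3, 3, 5) example.
rows = tf.constant([[[0, 1], [1, 2], [0, 2]]])
ind = make_gather_nd_indices(rows)  # same values as the hand-written `ind`

# Larger case from the question: (1, 16, 16, 128) input, k = 4 rows per block.
# Random row indices here are placeholders and may contain duplicates.
big = tf.random.uniform((1, 16, 16, 128))
big_rows = tf.random.uniform((1, 16, 4), maxval=16, dtype=tf.int32)
big_out = tf.gather_nd(big, make_gather_nd_indices(big_rows))  # shape (1, 16, 4, 128)
```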

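Use tf.gather() with the batch_dims argument:

import tensorflow as tf
# inds has shape (BATCH_SIZE, A, k): the rows to keep from each block
inds = tf.constant([[[0, 1], [1, 2], [0, 2]]])
output = tf.gather(inp, inds, batch_dims=2)  # shape (1, 3, 2, 5)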
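A self-contained check of this answer on the question's data, compared against the hand-written tf.gather_nd indices. With batch_dims=2 the first two axes (batch and block) act as batch axes and axis defaults to 2, so each block gathers its own k rows along the B axis and the result has shape (BATCH_SIZE, A, k, FEATURE_LENGTH). The random big_inds in the larger case are only a placeholder for whatever row indices you actually have.

```python
import tensorflow as tf

inp = tf.constant([[[[ 5, 38, 40, 13, 28],
                     [12,  6, 36, 20, 23],
                     [44, 35, 23, 46,  3]],
                    [[22, 32, 36, 20, 42],
                     [ 0, 19, 41, 36, 17],
                     [ 9, 35, 44,  7, 19]],
                    [[27, 10, 22, 10, 48],
                     [16, 42, 27,  7, 38],
                     [35, 32, 15, 39, 28]]]])  # shape (1, 3, 3, 5)

# One list of k=2 row indices per (batch, block) pair: shape (1, 3, 2).
inds = tf.constant([[[0, 1], [1, 2], [0, 2]]])

# batch_dims=2: axes 0 and 1 are batch axes and axis defaults to 2,
# so each block gathers its own rows along the B axis.
output = tf.gather(inp, inds, batch_dims=2)  # shape (1, 3, 2, 5)

# Same result as tf.gather_nd with explicit (batch, block, row) triples.
ind = tf.constant([[[[0, 0, 0], [0, 0, 1]],
                    [[0, 1, 1], [0, 1, 2]],
                    [[0, 2, 0], [0, 2, 2]]]])
output_nd = tf.gather_nd(inp, ind)

# The larger case only needs an index tensor of shape (1, 16, 4).
big = tf.random.uniform((1, 16, 16, 128))
big_inds = tf.random.uniform((1, 16, 4), maxval=16, dtype=tf.int32)
big_out = tf.gather(big, big_inds, batch_dims=2)  # shape (1, 16, 4, 128)
```

For the (1,16,16,128), k=4 case, the only thing to build is the (1, 16, 4) tensor of row indices, rather than the full (1, 16, 4, 3) array of gather_nd triples.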
