
Tensorflow: how to keep batch dimension when using tf.where()?

I'm trying to select the elements that are different from zero and work with them later. My input tensor has a batch dimension, which I want to keep so that data from different batches is not mixed. I think tf.gather_nd() would work for me, but first I need the indexes of the desired data, and for that I found tf.where(). I have tried the following:

import tensorflow as tf

img = tf.constant([[[1., 0., 0.],
                    [0., 0., 2.],
                    [0., 3., 0.]],
                   [[1., 2., 3.],
                    [0., 0., 1.],
                    [0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

# coordinates of every non-zero element of img
indexes = tf.where(tf.not_equal(img, 0.))

I would expect indexes to keep the batch dimension, but it has shape [7, 2]. I suspect the problem is that each batch has a different number of points satisfying the condition.

Is there a way to get the indexes while keeping the batch dimension? Thanks in advance.

EDIT: indexes has shape [7, 3], where the first dimension is the number of points and the second is the position of each point (including which batch it belongs to). But I need indexes to have an explicit batch dimension, because later I want to use it to gather data from img:

Y = tf.gather_nd(img, indexes)

I want Y to have a batch dimension, but since indexes doesn't have one, I get a flat tensor with data from different batches mixed together.

Actually, you may have done something wrong: when I run your code, indexes has shape (7, 3), not (7, 2). The 3 corresponds to the three dimensions of img, and the 7 to the number of non-zero elements in img.

Full result of sess.run(indexes):

array([[0, 0, 0],
       [0, 1, 2],
       [0, 2, 1],
       [1, 0, 0],
       [1, 0, 1],
       [1, 0, 2],
       [1, 1, 2]])
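
Since the first column of each row in indexes is the batch number, one possible way to get the batch dimension back is to use that column to group the rows into a ragged tensor. The following is only a sketch, not something taken from the question: it assumes TensorFlow 2.x with eager execution, and the names batch_ids, values_per_batch and indexes_per_batch are made up for illustration. It relies on tf.where() returning coordinates in row-major order, so the batch ids are already sorted.

```python
import tensorflow as tf

img = tf.constant([[[1., 0., 0.],
                    [0., 0., 2.],
                    [0., 3., 0.]],
                   [[1., 2., 3.],
                    [0., 0., 1.],
                    [0., 0., 0.]]], dtype=tf.float32)  # shape [2, 3, 3]

indexes = tf.where(tf.not_equal(img, 0.))  # shape [7, 3], row-major order
values = tf.gather_nd(img, indexes)        # shape [7], flat across batches

# Column 0 holds the batch number of each point; use it as the row id.
batch_ids = indexes[:, 0]
nrows = img.shape[0]  # static batch size (2 here)

# One variable-length row per batch element.
values_per_batch = tf.RaggedTensor.from_value_rowids(
    values, batch_ids, nrows=nrows)
# Same grouping for the in-image coordinates (batch column dropped).
indexes_per_batch = tf.RaggedTensor.from_value_rowids(
    indexes[:, 1:], batch_ids, nrows=nrows)

print(values_per_batch.to_list())
print(indexes_per_batch.to_list())
```

The rows have different lengths (3 points in the first batch element, 4 in the second), which is why a ragged tensor is used here instead of a dense one.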

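You can use tf.math.top_k() to get the values and their indices while keeping the batch dimension, and then compute a mask and apply it to both. With k equal to the size of the last axis, top_k returns every element of each row sorted in descending order, together with its original position in that row.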

img = tf.constant([[[1., 0., 0.], 
                    [0., 0., 2.],
                    [0., 3, 0.]], 
                   [[1., 2., 3.], 
                    [0., 0., 1.], 
                    [0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

values, indices = tf.math.top_k(img, k=3)
# values would be 
# [[[1., 0., 0.], [2., 0., 0.], [3., 0., 0.]],
#  [[3., 2., 1.], [1., 0., 0.], [0., 0., 0.]]]
# indices would be
# [[[0, 1, 2], [2, 0, 1], [1, 0, 2]],
#  [[2, 1, 0], [2, 0, 1], [0, 1, 2]]]

mask = tf.cast(values, dtype=tf.bool)
# mask would be
# [[[True, False, False], [True, False, False], [True, False, False]],
#  [[True, True, True], [True, False, False], [False, False, False]]]

Now you can get the non-zero values of img by applying mask to values, and the positions of those values along the last axis by applying mask to indices. You can also use tf.gather() with indices to read the values back from img, like this:

values2 = tf.gather(img, indices, batch_dims=2)
# values2 is identical to values above
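
The answer does not show how the mask is actually applied, so here is one way it might be done. This is a sketch under assumptions: TensorFlow 2.x with eager execution, a sentinel of -1 for unused slots (my choice, not part of the answer), and tf.not_equal() used in place of the cast to build the mask. The names padded_indices, nonzero_values and nonzero_indices are hypothetical.

```python
import tensorflow as tf

img = tf.constant([[[1., 0., 0.],
                    [0., 0., 2.],
                    [0., 3., 0.]],
                   [[1., 2., 3.],
                    [0., 0., 1.],
                    [0., 0., 0.]]], dtype=tf.float32)  # shape [2, 3, 3]

values, indices = tf.math.top_k(img, k=3)  # both have shape [2, 3, 3]
mask = tf.not_equal(values, 0.)            # True where the value is non-zero

# Option 1: keep the dense [2, 3, 3] shape and mark unused slots with -1.
padded_indices = tf.where(mask, indices, -tf.ones_like(indices))

# Option 2: drop the zero slots; batch and row dimensions are kept,
# only the innermost dimension becomes variable-length.
nonzero_values = tf.ragged.boolean_mask(values, mask)
nonzero_indices = tf.ragged.boolean_mask(indices, mask)

print(padded_indices.numpy().tolist())
print(nonzero_values.to_list())
print(nonzero_indices.to_list())
```

Option 1 is convenient when later ops need a fixed shape; option 2 avoids a sentinel but produces a tf.RaggedTensor. Note that the indices here are column positions within each row, not full coordinates like the ones tf.where() returns.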

