I'm trying to select the non-zero elements of a tensor and work with them later. My input tensor has a batch dimension, so I want to keep it and not mix data across batches. I think tf.gather_nd() would work for me, but first I have to get the indexes of the desired data, and for that I found tf.where(). I have tried the following:
import tensorflow as tf

img = tf.constant([[[1., 0., 0.],
                    [0., 0., 2.],
                    [0., 3., 0.]],
                   [[1., 2., 3.],
                    [0., 0., 1.],
                    [0., 0., 0.]]], dtype='float32')  # shape [2, 3, 3]
indexes = tf.where(tf.not_equal(img, 0.))
I would expect indexes to keep the batch dimension; however, it has shape [7, 2]. I suspect the problem comes from the different number of points satisfying the condition in different batches.
Is there a way to get the indexes while keeping the batch dimension? Thanks in advance.
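The suspicion above can be checked directly. This is a minimal sketch, assuming TensorFlow 2.x with eager execution: it counts the non-zero entries of each batch element by reducing over the two spatial axes. The name counts is just for illustration.

```python
import tensorflow as tf

img = tf.constant([[[1., 0., 0.],
                    [0., 0., 2.],
                    [0., 3., 0.]],
                   [[1., 2., 3.],
                    [0., 0., 1.],
                    [0., 0., 0.]]], dtype=tf.float32)  # shape [2, 3, 3]

# Reduce over axes 1 and 2 (the spatial dims) so one count remains per batch element.
counts = tf.math.count_nonzero(img, axis=[1, 2])
print(counts.numpy())  # [3 4]
```

Because the two batch elements hold 3 and 4 non-zero points, no dense tensor of shape [2, n, 3] can hold both sets of coordinates without padding, which is why tf.where returns a single flat list of coordinates.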
EDIT: indexes has shape [7, 3], where the first dim is the number of points and the second dim is the position of each point (including which batch it belongs to). But I need indexes to have a separate batch dimension, because later I want to use it to gather data from img:
Y = tf.gather_nd(img, indexes)
I want Y to have a batch dimension, but since indexes doesn't have one, I get a flat tensor with data from different batches mixed together.
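One way to keep a batch dimension despite the varying counts is a ragged tensor. The sketch below is an alternative to tf.where plus tf.gather_nd, not the asker's approach, and assumes TensorFlow 2.x with eager execution. It flattens the spatial dims and uses tf.ragged.boolean_mask, which keeps the leading (batch) dimension and lets each row have its own length.

```python
import tensorflow as tf

img = tf.constant([[[1., 0., 0.],
                    [0., 0., 2.],
                    [0., 3., 0.]],
                   [[1., 2., 3.],
                    [0., 0., 1.],
                    [0., 0., 0.]]], dtype=tf.float32)  # shape [2, 3, 3]

# Flatten each image so every batch element becomes one row: shape [2, 9].
flat = tf.reshape(img, [tf.shape(img)[0], -1])

# The masked dimension becomes ragged; the batch dimension is preserved.
Y = tf.ragged.boolean_mask(flat, tf.not_equal(flat, 0.))
print(Y.to_list())  # [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0, 1.0]]
```

Skipping the reshape also works; the result then has ragged shape [2, 3, None], with one variable-length row per image row.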
Actually, you may have done something wrong: when I run your code, indexes has shape (7, 3), not (7, 2). The 3 corresponds to your 3 dimensions, whereas the 7 corresponds to the number of non-zero elements in img.
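The shape claim can be confirmed without a session. This sketch assumes TensorFlow 2.x, where eager execution replaces sess.run; tf.where returns an int64 tensor with one row per non-zero element and one column per input dimension.

```python
import tensorflow as tf

img = tf.constant([[[1., 0., 0.],
                    [0., 0., 2.],
                    [0., 3., 0.]],
                   [[1., 2., 3.],
                    [0., 0., 1.],
                    [0., 0., 0.]]], dtype=tf.float32)  # shape [2, 3, 3]

# One coordinate row per non-zero element; each row has len(img.shape) entries.
indexes = tf.where(tf.not_equal(img, 0.))
print(indexes.shape.as_list())  # [7, 3]
```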
Full result of sess.run(indexes):
array([[0, 0, 0],
       [0, 1, 2],
       [0, 2, 1],
       [1, 0, 0],
       [1, 0, 1],
       [1, 0, 2],
       [1, 1, 2]])
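Since the first column of that array is the batch index, the flat result of tf.gather_nd can be regrouped per batch afterwards. This is a sketch assuming TensorFlow 2.x; it uses tf.RaggedTensor.from_value_rowids with column 0 as the row ids. The names Y_per_batch, pos_per_batch and batch_size are illustrative only.

```python
import tensorflow as tf

img = tf.constant([[[1., 0., 0.],
                    [0., 0., 2.],
                    [0., 3., 0.]],
                   [[1., 2., 3.],
                    [0., 0., 1.],
                    [0., 0., 0.]]], dtype=tf.float32)  # shape [2, 3, 3]

indexes = tf.where(tf.not_equal(img, 0.))  # [7, 3]; column 0 is the batch index
Y = tf.gather_nd(img, indexes)              # flat, shape [7]

# Batch size as int64, matching the dtype of the row ids taken from indexes.
batch_size = tf.shape(img, out_type=tf.int64)[0]

# tf.where emits coordinates in row-major order, so the batch column is
# already sorted, as from_value_rowids requires. Passing nrows keeps batch
# elements with no non-zero entries as empty rows.
Y_per_batch = tf.RaggedTensor.from_value_rowids(
    values=Y, value_rowids=indexes[:, 0], nrows=batch_size)
pos_per_batch = tf.RaggedTensor.from_value_rowids(
    values=indexes[:, 1:], value_rowids=indexes[:, 0], nrows=batch_size)

print(Y_per_batch.to_list())    # [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0, 1.0]]
print(pos_per_batch.to_list())  # [[[0, 0], [1, 2], [2, 1]], [[0, 0], [0, 1], [0, 2], [1, 2]]]
```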
You may use tf.math.top_k() to get values and indices sorted by value along the last axis, which keeps the batch dimension, and then compute a mask and apply it to the values and indices.
import tensorflow as tf

img = tf.constant([[[1., 0., 0.],
                    [0., 0., 2.],
                    [0., 3., 0.]],
                   [[1., 2., 3.],
                    [0., 0., 1.],
                    [0., 0., 0.]]], dtype='float32')  # shape [2, 3, 3]
values, indices = tf.math.top_k(img, k=3)
# values would be
# [[[1., 0., 0.], [2., 0., 0.], [3., 0., 0.]],
# [[3., 2., 1.], [1., 0., 0.], [0., 0., 0.]]]
# indices would be
# [[[0, 1, 2], [2, 0, 1], [1, 0, 2]],
# [[2, 1, 0], [2, 0, 1], [0, 1, 2]]]
mask = tf.cast(values, dtype=tf.bool)
# mask would be
# [[[True, False, False], [True, False, False], [True, False, False]],
# [[True, True, True], [True, False, False], [False, False, False]]]
Now you can get the non-zero values of img by using values and mask, and the non-zero indices of img by using indices and mask. You can also use tf.gather() with batch_dims to recover the values from img and indices, like this:
values2 = tf.gather(img, indices, batch_dims=2)
# values2 will be same with the above values
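Applying the mask to values and indices can be sketched with tf.ragged.boolean_mask, assuming TensorFlow 2.x. Because tf.math.top_k works along the last axis, the result has one variable-length list per image row (ragged shape [2, 3, None]), and nz_indices holds column positions within each row. The names nz_values and nz_indices are illustrative.

```python
import tensorflow as tf

img = tf.constant([[[1., 0., 0.],
                    [0., 0., 2.],
                    [0., 3., 0.]],
                   [[1., 2., 3.],
                    [0., 0., 1.],
                    [0., 0., 0.]]], dtype=tf.float32)  # shape [2, 3, 3]

values, indices = tf.math.top_k(img, k=3)  # sorted per row, descending
mask = tf.cast(values, dtype=tf.bool)      # True where the value is non-zero

# Keep only the masked entries; the last dimension becomes ragged.
nz_values = tf.ragged.boolean_mask(values, mask)
nz_indices = tf.ragged.boolean_mask(indices, mask)

print(nz_values.to_list())   # [[[1.0], [2.0], [3.0]], [[3.0, 2.0, 1.0], [1.0], []]]
print(nz_indices.to_list())  # [[[0], [2], [1]], [[2, 1, 0], [2], []]]
```

To get one list per batch element instead of per row, img can first be reshaped to [batch, -1] before calling tf.math.top_k with k equal to the flattened length.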