简体   繁体   English

Tensorflow:使用 tf.where() 时如何保持批量维度?

[英]Tensorflow: how to keep batch dimension when using tf.where()?

I'm trying to select elements differents from zero and work with them later.我正在尝试 select 元素不同于零,然后再使用它们。 My input tensor has batch dimension, so I want to keep it and don't mix data over batches.我的输入张量具有批次维度,因此我想保留它并且不要将数据混合到批次中。 I think tf.gather_nd() would work for me, but first I have to get the indexes of the desired data and I found tf.where() .我认为tf.gather_nd()对我有用,但首先我必须获取所需数据的索引,然后我找到tf.where() I have tried the following:我尝试了以下方法:

img = tf.constant([[[1., 0., 0.], 
                    [0., 0., 2.],
                    [0., 3, 0.]], 
                   [[1., 2., 3.], 
                    [0., 0., 1.], 
                    [0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

indexes = tf.where(tf.not_equal(img, 0.))

I would expect indexes to keep batch dimension, however it has shape [7, 2] .我希望indexes保持批量维度,但是它的形状为[7, 2] I suspect the problem comes from having different number of points that satisfies the condition in different batches.我怀疑问题出在不同批次中满足条件的点数不同。

Is there a way to get the indexes keeping batch dimension?有没有办法让索引保持批量维度? Thanks in advance.提前致谢。

EDIT: indexes has shape [7, 3] where first dim refers to number of points and the second dim refers to the position of the point (incluiding which batch it belongs to).编辑: indexes的形状为[7, 3] ,其中第一个暗淡指的是点数,第二个暗淡指的是该点的 position(包括它属于哪个批次)。 But I need indexes to have the specific batch dimension, because later I want to use it to ghater data from img :但是我需要indexes具有特定的批处理维度,因为稍后我想用它来收集来自img的数据:

Y = tf.gather_nd(img, indexes)

I want Y to have batch dimension, but as indexes hasn't, I get a flat tensor with data from different bateches mixed.我希望Y具有批次维度,但由于indexes没有,我得到一个平坦的张量,其中混合了来自不同批次的数据。

Actually, you may have done something wrong : when I run your code, indexes is of dimension (7,3) and not (7,2) . 实际上,您可能做错了什么:运行代码时, indexes的尺寸为(7,3)而不是(7,2) The 3 correspond to your 3 dimensions, whereas the 7 corresponds to the number of non-zero elements in img . 3对应于3维,而7对应于img非零元素的数量。

Full result of sess.run(indexes) : sess.run(indexes)完整结果:

array([[0, 0, 0],
      [0, 1, 2],
      [0, 2, 1],
      [1, 0, 0],
      [1, 0, 1],
      [1, 0, 2],
      [1, 1, 2]])

You may use tf.math.top_k() to get values and indices with batch from inputs based on their values and then compute and apply mask to the values and indices.您可以使用tf.math.top_k()根据其值从输入中批量获取值和索引,然后计算掩码并将其应用于值和索引。

img = tf.constant([[[1., 0., 0.], 
                    [0., 0., 2.],
                    [0., 3, 0.]], 
                   [[1., 2., 3.], 
                    [0., 0., 1.], 
                    [0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

values, indices = tf.math.top_k(img, k=3)
# values would be 
# [[[1., 0., 0.], [2., 0., 0.], [3., 0., 0.]],
#  [[3., 2., 1.], [1., 0., 0.], [0., 0., 0.]]]
# indices would be
# [[[0, 1, 2], [2, 0, 1], [1, 0, 2]],
#  [[2, 1, 0], [2, 0, 1], [0, 1, 2]]]

mask = tf.cast(values, dtype=tf.bool)
# mask would be
# [[[True, False, False], [True, False, False], [True, False, False]],
#  [[True, True, True], [True, False, False], [False, False, False]]]

Now you can get non-zero values of img by using values and mask and also get non-zero indices of img by using indices and mask .现在,您可以使用valuesmask获得img的非零值,也可以使用indicesmask获得img的非零索引。 And you can use tf.gather() to get values from img and indices as like:您可以使用tf.gather()imgindices获取值,如下所示:

values2 = tf.gather(img, indices, batch_dims=2)
# values2 will be same with the above values

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM