繁体   English   中英

Tensorflow:使用 tf.where() 时如何保持批量维度?

[英]Tensorflow: how to keep batch dimension when using tf.where()?

我正在尝试 select 元素不同于零,然后再使用它们。 我的输入张量具有批次维度,因此我想保留它并且不要将数据混合到批次中。 我认为tf.gather_nd()对我有用,但首先我必须获取所需数据的索引,然后我找到tf.where() 我尝试了以下方法:

img = tf.constant([[[1., 0., 0.], 
                    [0., 0., 2.],
                    [0., 3, 0.]], 
                   [[1., 2., 3.], 
                    [0., 0., 1.], 
                    [0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

indexes = tf.where(tf.not_equal(img, 0.))

我希望indexes保持批量维度,但是它的形状为[7, 2] 我怀疑问题出在不同批次中满足条件的点数不同。

有没有办法让索引保持批量维度? 提前致谢。

编辑: indexes的形状为[7, 3] ,其中第一个暗淡指的是点数,第二个暗淡指的是该点的 position(包括它属于哪个批次)。 但是我需要indexes具有特定的批处理维度,因为稍后我想用它来收集来自img的数据:

Y = tf.gather_nd(img, indexes)

我希望Y具有批次维度,但由于indexes没有,我得到一个平坦的张量,其中混合了来自不同批次的数据。

实际上,您可能做错了什么:运行代码时, indexes的尺寸为(7,3)而不是(7,2) 3对应于3维,而7对应于img非零元素的数量。

sess.run(indexes)完整结果:

array([[0, 0, 0],
      [0, 1, 2],
      [0, 2, 1],
      [1, 0, 0],
      [1, 0, 1],
      [1, 0, 2],
      [1, 1, 2]])

您可以使用tf.math.top_k()根据其值从输入中批量获取值和索引,然后计算掩码并将其应用于值和索引。

img = tf.constant([[[1., 0., 0.], 
                    [0., 0., 2.],
                    [0., 3, 0.]], 
                   [[1., 2., 3.], 
                    [0., 0., 1.], 
                    [0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

values, indices = tf.math.top_k(img, k=3)
# values would be 
# [[[1., 0., 0.], [2., 0., 0.], [3., 0., 0.]],
#  [[3., 2., 1.], [1., 0., 0.], [0., 0., 0.]]]
# indices would be
# [[[0, 1, 2], [2, 0, 1], [1, 0, 2]],
#  [[2, 1, 0], [2, 0, 1], [0, 1, 2]]]

mask = tf.cast(values, dtype=tf.bool)
# mask would be
# [[[True, False, False], [True, False, False], [True, False, False]],
#  [[True, True, True], [True, False, False], [False, False, False]]]

现在,您可以使用valuesmask获得img的非零值,也可以使用indicesmask获得img的非零索引。 您可以使用tf.gather()imgindices获取值,如下所示:

values2 = tf.gather(img, indices, batch_dims=2)
# values2 will be same with the above values

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM