[英]Tensorflow: how to keep batch dimension when using tf.where()?
我正在尝试 select 元素不同于零,然后再使用它们。 我的输入张量具有批次维度,因此我想保留它并且不要将数据混合到批次中。 我认为tf.gather_nd()
对我有用,但首先我必须获取所需数据的索引,然后我找到tf.where()
。 我尝试了以下方法:
img = tf.constant([[[1., 0., 0.],
[0., 0., 2.],
[0., 3, 0.]],
[[1., 2., 3.],
[0., 0., 1.],
[0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]
indexes = tf.where(tf.not_equal(img, 0.))
我希望indexes
保持批量维度,但是它的形状为[7, 2]
。 我怀疑问题出在不同批次中满足条件的点数不同。
有没有办法让索引保持批量维度? 提前致谢。
编辑: indexes
的形状为[7, 3]
,其中第一个暗淡指的是点数,第二个暗淡指的是该点的 position(包括它属于哪个批次)。 但是我需要indexes
具有特定的批处理维度,因为稍后我想用它来收集来自img
的数据:
Y = tf.gather_nd(img, indexes)
我希望Y
具有批次维度,但由于indexes
没有,我得到一个平坦的张量,其中混合了来自不同批次的数据。
实际上,您可能做错了什么:运行代码时, indexes
的尺寸为(7,3)
而不是(7,2)
。 3
对应于3维,而7
对应于img
非零元素的数量。
sess.run(indexes)
完整结果:
array([[0, 0, 0],
[0, 1, 2],
[0, 2, 1],
[1, 0, 0],
[1, 0, 1],
[1, 0, 2],
[1, 1, 2]])
您可以使用tf.math.top_k()
根据其值从输入中批量获取值和索引,然后计算掩码并将其应用于值和索引。
img = tf.constant([[[1., 0., 0.],
[0., 0., 2.],
[0., 3, 0.]],
[[1., 2., 3.],
[0., 0., 1.],
[0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]
values, indices = tf.math.top_k(img, k=3)
# values would be
# [[[1., 0., 0.], [2., 0., 0.], [3., 0., 0.]],
# [[3., 2., 1.], [1., 0., 0.], [0., 0., 0.]]]
# indices would be
# [[[0, 1, 2], [2, 0, 1], [1, 0, 2]],
# [[2, 1, 0], [2, 0, 1], [0, 1, 2]]]
mask = tf.cast(values, dtype=tf.bool)
# mask would be
# [[[True, False, False], [True, False, False], [True, False, False]],
# [[True, True, True], [True, False, False], [False, False, False]]]
现在,您可以使用values
和mask
获得img
的非零值,也可以使用indices
和mask
获得img
的非零索引。 您可以使用tf.gather()
从img
和indices
获取值,如下所示:
values2 = tf.gather(img, indices, batch_dims=2)
# values2 will be same with the above values
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.