[英]Tensorflow: how to keep batch dimension when using tf.where()?
I'm trying to select elements differents from zero and work with them later.我正在尝试 select 元素不同于零,然后再使用它们。 My input tensor has batch dimension, so I want to keep it and don't mix data over batches.
我的输入张量具有批次维度,因此我想保留它并且不要将数据混合到批次中。 I think
tf.gather_nd()
would work for me, but first I have to get the indexes of the desired data and I found tf.where()
.我认为
tf.gather_nd()
对我有用,但首先我必须获取所需数据的索引,然后我找到tf.where()
。 I have tried the following:我尝试了以下方法:
img = tf.constant([[[1., 0., 0.],
[0., 0., 2.],
[0., 3, 0.]],
[[1., 2., 3.],
[0., 0., 1.],
[0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]
indexes = tf.where(tf.not_equal(img, 0.))
I would expect indexes
to keep batch dimension, however it has shape [7, 2]
.我希望
indexes
保持批量维度,但是它的形状为[7, 2]
。 I suspect the problem comes from having different number of points that satisfies the condition in different batches.我怀疑问题出在不同批次中满足条件的点数不同。
Is there a way to get the indexes keeping batch dimension?有没有办法让索引保持批量维度? Thanks in advance.
提前致谢。
EDIT: indexes
has shape [7, 3]
where first dim refers to number of points and the second dim refers to the position of the point (incluiding which batch it belongs to).编辑:
indexes
的形状为[7, 3]
,其中第一个暗淡指的是点数,第二个暗淡指的是该点的 position(包括它属于哪个批次)。 But I need indexes
to have the specific batch dimension, because later I want to use it to ghater data from img
:但是我需要
indexes
具有特定的批处理维度,因为稍后我想用它来收集来自img
的数据:
Y = tf.gather_nd(img, indexes)
I want Y
to have batch dimension, but as indexes
hasn't, I get a flat tensor with data from different bateches mixed.我希望
Y
具有批次维度,但由于indexes
没有,我得到一个平坦的张量,其中混合了来自不同批次的数据。
Actually, you may have done something wrong : when I run your code, indexes
is of dimension (7,3)
and not (7,2)
. 实际上,您可能做错了什么:运行代码时,
indexes
的尺寸为(7,3)
而不是(7,2)
。 The 3
correspond to your 3 dimensions, whereas the 7
corresponds to the number of non-zero elements in img
. 3
对应于3维,而7
对应于img
非零元素的数量。
Full result of sess.run(indexes)
: sess.run(indexes)
完整结果:
array([[0, 0, 0],
[0, 1, 2],
[0, 2, 1],
[1, 0, 0],
[1, 0, 1],
[1, 0, 2],
[1, 1, 2]])
You may use tf.math.top_k()
to get values and indices with batch from inputs based on their values and then compute and apply mask to the values and indices.您可以使用
tf.math.top_k()
根据其值从输入中批量获取值和索引,然后计算掩码并将其应用于值和索引。
img = tf.constant([[[1., 0., 0.],
[0., 0., 2.],
[0., 3, 0.]],
[[1., 2., 3.],
[0., 0., 1.],
[0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]
values, indices = tf.math.top_k(img, k=3)
# values would be
# [[[1., 0., 0.], [2., 0., 0.], [3., 0., 0.]],
# [[3., 2., 1.], [1., 0., 0.], [0., 0., 0.]]]
# indices would be
# [[[0, 1, 2], [2, 0, 1], [1, 0, 2]],
# [[2, 1, 0], [2, 0, 1], [0, 1, 2]]]
mask = tf.cast(values, dtype=tf.bool)
# mask would be
# [[[True, False, False], [True, False, False], [True, False, False]],
# [[True, True, True], [True, False, False], [False, False, False]]]
Now you can get non-zero values of img
by using values
and mask
and also get non-zero indices of img
by using indices
and mask
.现在,您可以使用
values
和mask
获得img
的非零值,也可以使用indices
和mask
获得img
的非零索引。 And you can use tf.gather()
to get values from img
and indices
as like:您可以使用
tf.gather()
从img
和indices
获取值,如下所示:
values2 = tf.gather(img, indices, batch_dims=2)
# values2 will be same with the above values
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.