简体   繁体   English

Tensorflow:使用argmax切片张量

[英]Tensorflow: using argmax to slice a tensor

I have a tensor with shape tf.shape(t1) = [1, 1000, 400] and I obtain the indices of the maxima on the 3rd dimension using max_ind = tf.argmax(t1, axis=-1) which has shape [1, 1000] . 我有一个形状为tf.shape(t1) = [1, 1000, 400]的张量,并使用max_ind = tf.argmax(t1, axis=-1)获得形状为[1, 1000] Now I have a second tensor that has the same shape as t1 : tf.shape(t2) = [1, 1000, 400] . 现在我有一个第二张量,其形状与t1相同: tf.shape(t2) = [1, 1000, 400]

I want to use the maxima indices from t1 to slice t2 so the output has the form 我想使用从t1到切片t2的最大值索引,因此输出形式为

[1, 1000]

A more visual description: The resulting tensor should be like the result of tf.reduce_max(t2, axis=-1) but with the location of the maxima in t1 更直观的描述:所得张量应类似于tf.reduce_max(t2, axis=-1)的结果,但最大值应位于t1

You can achieve this through tf.gather_nd , although it is not really straightforward. 您可以通过tf.gather_nd来实现这tf.gather_nd ,尽管它并不是很简单。 For example, 例如,

shape = t1.shape.as_list()
xy_ind = np.stack(np.mgrid[:shape[0], :shape[1]], axis=-1)
gather_ind = tf.concat([xy_ind, max_ind[..., None]], axis=-1)
sliced_t2 = tf.gather_nd(t2, gather_ind)

If on the other hand the shape of your input is unknown as graph construction time, you could use 另一方面,如果您不知道输入的形状是图形构建时间,则可以使用

shape = tf.shape(t1)
xy_ind = tf.stack(tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]),
                              indexing='ij'), axis=-1)

and the remainder is the same as above. 其余与上述相同。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM