[英]Gather elements along second dimension of tensor
假設values
和張量T
都具有形狀(N,K)
。 現在,如果我們從矩陣的角度來考慮它們,我希望T
每一行都能獲得與索引相對應的行元素,其中values
具有最大值。 我可以很容易地找到那些指數
max_indicies = tf.argmax(T, 1)
它返回一個形狀(N)
的張量。 現在,我如何從T
收集這些指數,以便得到N
形狀? 我試過了
result = tf.gather(T,max_indices)
但它沒有做正確的事 - 它返回一些形狀(N,K)
,這意味着它沒有收集任何東西。
你可以使用tf.gather_nd 。
例如,
import tensorflow as tf
sess = tf.InteractiveSession()
values = tf.constant([[0, 0, 0, 1],
[0, 1, 0, 0],
[0, 0, 1, 0]])
T = tf.constant([[0, 1, 2 , 3],
[4, 5, 6 , 7],
[8, 9, 10, 11]])
max_indices = tf.argmax(values, axis=1)
# If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0].
result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0],
dtype=max_indices.dtype),
max_indices),
axis=1))
print(result.eval())
但是當values
和T
的排名較高時,使用tf.gather_nd
會有點尷尬。 我在這個問題上發布了當前的解決方案。 在高維values
和T
情況下,可能有更好的解決方案。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.