簡體   English   中英

沿張量的二維收集元素

[英]Gather elements along second dimension of tensor

假設values和張量T都具有形狀(N,K) 現在,如果我們從矩陣的角度來考慮它們,我希望T每一行都能獲得與索引相對應的行元素,其中values具有最大值。 我可以很容易地找到那些指數

max_indicies = tf.argmax(T, 1)

它返回一個形狀(N)的張量。 現在,我如何從T收集這些指數,以便得到N形狀? 我試過了

result = tf.gather(T,max_indices)

但它沒有做正確的事 - 它返回一些形狀(N,K) ,這意味着它沒有收集任何東西。

你可以使用tf.gather_nd

例如,

import tensorflow as tf

sess = tf.InteractiveSession()

values = tf.constant([[0, 0, 0, 1],
                      [0, 1, 0, 0],
                      [0, 0, 1, 0]])

T = tf.constant([[0, 1, 2 ,  3],
                 [4, 5, 6 ,  7],
                 [8, 9, 10, 11]])

max_indices = tf.argmax(values, axis=1)
# If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0].
result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0], 
                                            dtype=max_indices.dtype),
                                   max_indices),
                                  axis=1))

print(result.eval())

但是當valuesT的排名較高時,使用tf.gather_nd會有點尷尬。 我在這個問題上發布了當前的解決方案。 在高維valuesT情況下,可能有更好的解決方案。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM