[英]Different behavior of gather() function as seen in tensorflow and pytorch
我有一個張量形狀(16, 4096, 3)
。 我還有一個形狀為(16, 32768, 3)
的索引張量。 我正在嘗試沿dim=1
收集值。 這是最初在pytorch使用進行gather
功能,如圖所示如下─
# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)
請注意,輸出b
的大小與idx
的大小相同。 但是,當我應用tensorflow的gather
函數時,會得到完全不同的輸出。 發現輸出尺寸不匹配,如下所示-
b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)
我也嘗試使用tf.gather_nd
但徒勞無功。 見下文-
b = tf.gather_nd(a, idx)
# b.shape (16, 32768)
為什么我得到不同形狀的張量? 我想獲得與pytorch計算的形狀相同的張量。
如何獲得與pytorch相同的結果?
如果我對您的理解正確,那么tf.gather_nd
是您所需要的。 如果沒有,請更加清楚。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.