簡體   English   中英

在tensorflow和pytorch中看到的collect()函數的不同行為

[英]Different behavior of gather() function as seen in tensorflow and pytorch

我有一個張量形狀(16, 4096, 3) 我還有一個形狀為(16, 32768, 3)的索引張量。 我正在嘗試沿dim=1收集值。 這是最初在pytorch使用進行gather功能,如圖所示如下─

# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)

請注意,輸出b的大小與idx的大小相同。 但是,當我應用tensorflow的gather函數時,會得到完全不同的輸出。 發現輸出尺寸不匹配,如下所示-

b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)

我也嘗試使用tf.gather_nd但徒勞無功。 見下文-

b = tf.gather_nd(a, idx)
# b.shape (16, 32768)

為什么我得到不同形狀的張量? 我想獲得與pytorch計算的形狀相同的張量。

如何獲得與pytorch相同的結果?

如果我對您的理解正確,那么tf.gather_nd是您所需要的。 如果沒有,請更加清楚。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM