簡體   English   中英

在 tf.nn.top_k 中加入 torch.topk 的 dim 參數

[英]Incorporating dim parameter of torch.topk in tf.nn.top_k

Pytorch 提供torch.topk(input, k, dim=None, largest=True, sorted=True)函數來計算給定input張量沿給定維度dim k最大元素。

我有一個形狀(16, 512, 4096)的張量,我以下列方式使用torch.topk

# inputs.shape (16L, 512L, 4096L)
dist, idx = torch.topk(inputs, 64, dim=2, largest=False, sorted=False)
# dist.shape (16L, 512L, 64L), idx.shape (16L, 512L, 64L)

我發現類似的 tensorflow 實現如下 - tf.nn.top_k(input, k=1, sorted=True, name=None)

我的問題是如何在tf.nn.top_ktf.nn.top_k dim=2參數,從而實現與 pytorch 計算出的形狀相同的張量?

tf.nn.top_k處理輸入的最后一個維度。 這意味着它應該像您的示例一樣工作:

dist, idx = tf.nn.top_k(inputs, 64, sorted=False)

一般來說,你可以想象 Tensorflow 版本像 Pytorch 版本一樣工作,硬編碼dim=-1 ,即最后一個維度。

但是,看起來您實際上想要 k 個最小的元素。 在這種情況下,我們可以做

dist, idx = tf.nn.top_k(-1*inputs, 64, sorted=False)
dist = -1*dist

所以我們取k個最大的輸入,也就是原始輸入中k個最小的。 然后我們反轉值的負數。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM