[英]Incorporating dim parameter of torch.topk in tf.nn.top_k
Pytorch 提供torch.topk(input, k, dim=None, largest=True, sorted=True)
函數來計算給定input
張量沿給定維度dim
k
最大元素。
我有一個形狀(16, 512, 4096)
的張量,我以下列方式使用torch.topk
# inputs.shape (16L, 512L, 4096L)
dist, idx = torch.topk(inputs, 64, dim=2, largest=False, sorted=False)
# dist.shape (16L, 512L, 64L), idx.shape (16L, 512L, 64L)
我發現類似的 tensorflow 實現如下 - tf.nn.top_k(input, k=1, sorted=True, name=None)
。
我的問題是如何在tf.nn.top_k
中tf.nn.top_k
dim=2
參數,從而實現與 pytorch 計算出的形狀相同的張量?
tf.nn.top_k
處理輸入的最后一個維度。 這意味着它應該像您的示例一樣工作:
dist, idx = tf.nn.top_k(inputs, 64, sorted=False)
一般來說,你可以想象 Tensorflow 版本像 Pytorch 版本一樣工作,硬編碼dim=-1
,即最后一個維度。
但是,看起來您實際上想要 k 個最小的元素。 在這種情況下,我們可以做
dist, idx = tf.nn.top_k(-1*inputs, 64, sorted=False)
dist = -1*dist
所以我們取k個最大的負輸入,也就是原始輸入中k個最小的。 然后我們反轉值的負數。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.