簡體   English   中英

在 tf.nn.top_k 中加入 torch.topk 的 dim

[英]Incorporating dim of torch.topk in tf.nn.top_k

Pytorch 提供torch.topk(input, k, dim=None, largest=True, sorted=True) function 來計算給定input張量沿給定維度dimk個最大元素。

我有一個形狀(64, 128, 512)的張量,我正在以下列方式使用torch.topk -

reduce = input.topk(k, dim=1).values

我發現類似的 tensorflow 實現如下 - tf.nn.top_k(input, k=1, sorted=True, name=None)

我的問題是如何在tf.nn.top_k中加入dim=1參數,以實現與 pytorch 計算的形狀相同的張量?

我同意@jodag,你將不得不轉置或重塑你的張量,因為tf.math.top_k總是在最后一個維度上工作。

您還可以做的是首先獲取張量中沿某個維度的所有最大值,然后從該張量中獲取前k個值:

import tensorflow as tf
tf.random.set_seed(2)

k = 3
tensor = tf.random.uniform((2, 4, 6), maxval=10, dtype=tf.int32)
max_tensor = tf.reduce_max(tensor, axis=1)
k_max_tensor = tf.math.top_k(max_tensor, k=k, sorted=True).values

print('Original tensor --> ', tensor)
print('Max tensor --> ', max_tensor)
print('K-Max tensor --> ', k_max_tensor)
print('Unique K-Max tensor', tf.unique(tf.reshape(k_max_tensor, (tf.math.reduce_prod(tf.shape(k_max_tensor)), ))).y)
Original tensor -->  tf.Tensor(
[[[1 6 2 7 3 6]
  [7 5 1 1 0 6]
  [9 1 3 9 1 4]
  [6 0 6 2 4 0]]

 [[4 6 8 2 4 7]
  [5 0 8 2 8 9]
  [0 2 0 0 9 8]
  [9 3 8 9 0 6]]], shape=(2, 4, 6), dtype=int32)
Max tensor -->  tf.Tensor(
[[9 6 6 9 4 6]
 [9 6 8 9 9 9]], shape=(2, 6), dtype=int32)
K-Max tensor -->  tf.Tensor(
[[9 9 6]
 [9 9 9]], shape=(2, 3), dtype=int32)
Unique K-Max tensor tf.Tensor([9 6], shape=(2,), dtype=int32)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM