簡體   English   中英

如何在 Tensorflow 張量中僅舍入 k 個最大元素

[英]how to round up only k greatest elements in a Tensorflow tensor

假設有一個 TensorFlow 張量 - 例如 [0.1,0.2,0.3,0.4]。 我想四舍五入k最大元素並向下舍入 rest。 (例如,當k =2時,我希望得到[0,0,1,1]。當k =3時,我希望得到[0,1,1,1]。)

我想只使用 TensorFlow 操作來實現這個 function 。 我如何實現這一目標?

嘗試這樣的事情:

import tensorflow as tf

x = tf.constant([0.1,0.2,0.3,0.4])
k = 3
greatest = tf.math.top_k(x, k=k).indices 
tensor = tf.tensor_scatter_nd_update(tf.zeros_like(x), tf.reshape(greatest, (tf.shape(greatest)[0], 1)), tf.ones_like(tf.gather(x, greatest)))

k = 3:

tf.Tensor([0. 1. 1. 1.], shape=(4,), dtype=float32)

k = 2:

tf.Tensor([0. 0. 1. 1.], shape=(4,), dtype=float32)

這種方法並沒有真正舍入,因為將0.30.4舍入到最接近的 integer 會導致零,這不是您想要的。 因此,我只需將張量中的最高k值轉換為 1,並將 rest 轉換為零,但如果它仍然是二進制分類,這對於您的用例來說應該足夠了。

如果您真的想四舍五入最大的k值,請使用tf.math.ceil而不是tf.ones_like

tensor = tf.tensor_scatter_nd_update(tf.zeros_like(x), tf.reshape(greatest, (tf.shape(greatest)[0], 1)), tf.ceil((tf.gather(x, greatest))))

您可以為此使用tf.math.top_k function 將返回給定張量中 k 最大元素的值和索引。

https://www.tensorflow.org/api_docs/python/tf/math/top_k

然后,您可以使用返回的索引將張量中的值設置為特定值。

以下解決方案將問題中提到的值四舍五入。

import tensorflow as tf

x = tf.constant([0.1,0.2,0.3,0.4])
k = 3

# retrieve min and max values
max_value = tf.math.ceil(tf.math.reduce_max(x))
min_value = tf.math.floor(tf.math.reduce_min(x))

# retrieve the k largest elements
k_largest = tf.math.top_k(x, k=k)

# reshape the indices, required for ‘scatter‘ function
indices = tf.reshape(k_largest.indices, (-1,1))
values = k_largest.values

# initialize update tensor with max_value
updates = max_value * tf.ones_like(values)
# initialize result with min_value
x_new = min_value * tf.ones_like(x)
# update values for k_largest indices
x_new = tf.tensor_scatter_nd_update(
    x_new, indices, updates)

print(x_new)

如果您要求的ceilfloor操作應針對每個元素應用,而不是應用於張量內的minmax ,則如下所示:

import tensorflow as tf

x = tf.constant([0.1,0.2,0.3,0.4])
k = 3

# retrieve the k largest elements
k_largest = tf.math.top_k(x, k=k)
# reshape the indices, required for ‘scatter‘ function
indices = tf.reshape(k_largest.indices, (-1,1))

# get floored values
floored_values = tf.math.floor(x)
# get ceiled values only for top-k
ceiled_values = tf.math.ceil(k_largest.values)

# initialize result with per element floored values
x_new = floored_values
# update values for k_largest indices with per element ceiled values
x_new = tf.tensor_scatter_nd_update(
    floored_values, indices, ceiled_values)

print(x_new)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM