簡體   English   中英

Tensorflow 如何檢查張量行是否僅為零?

[英]Tensorflow how to check if a tensor row is only zeroes?

我正在訓練一個簡單的網絡來預測單個對象的邊界框坐標。 然而,有些圖片找不到對象。 由於網絡總是進行預測,因此它還會預測一個介於 0 和 1 之間的置信度值,該值應該表示圖片中存在物體的概率。 我的預測張量稱為logits ,它是一個(batch_size, 5)張量(置信度、x、y、寬度和高度)。 同樣, labels張量也是(batch_size, 5)

以前我只用總是有對象的圖像進行訓練,所以我基本上可以做到

loss = tf.l2_loss(logits - labels)

我也想用沒有對象的圖片開始訓練,當圖片中沒有對象時,我不希望網絡因為它預測的任何坐標而受到懲罰。 在這種情況下,重要的是置信度值,它應該接近 0(沒有對象)。

我應該如何構建我的標簽和損失函數來實現這一點? 我可以將沒有對象的圖像標簽設置為全零,但如何檢查特定行是否僅為零? 在這種情況下,logits 中的相應行也需要設置為零(置信度值除外!),以便坐標引起的損失也為零。

您可以使用tf.math.count_nonzero()來檢查張量是否全為零。 您可以在此處查看指南。

一種查找張量行是否全為零的方法:

import tensorflow as tf

image = tf.fill([8,8], 0)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
image_row = tf.slice(image, [1,0], [1, -1])
total = tf.reduce_sum(tf.abs(image_row))
is_all_zero = tf.equal(total, 0)
print sess.run([total, is_all_zero, image_row])

結果:

[0, True, array([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM