簡體   English   中英

Tensorflow:使用 tf.where() 時如何保持批量維度?

[英]Tensorflow: how to keep batch dimension when using tf.where()?

我正在嘗試 select 元素不同於零,然后再使用它們。 我的輸入張量具有批次維度,因此我想保留它並且不要將數據混合到批次中。 我認為tf.gather_nd()對我有用,但首先我必須獲取所需數據的索引,然后我找到tf.where() 我嘗試了以下方法:

img = tf.constant([[[1., 0., 0.], 
                    [0., 0., 2.],
                    [0., 3, 0.]], 
                   [[1., 2., 3.], 
                    [0., 0., 1.], 
                    [0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

indexes = tf.where(tf.not_equal(img, 0.))

我希望indexes保持批量維度,但是它的形狀為[7, 2] 我懷疑問題出在不同批次中滿足條件的點數不同。

有沒有辦法讓索引保持批量維度? 提前致謝。

編輯: indexes的形狀為[7, 3] ,其中第一個暗淡指的是點數,第二個暗淡指的是該點的 position(包括它屬於哪個批次)。 但是我需要indexes具有特定的批處理維度,因為稍后我想用它來收集來自img的數據:

Y = tf.gather_nd(img, indexes)

我希望Y具有批次維度,但由於indexes沒有,我得到一個平坦的張量,其中混合了來自不同批次的數據。

實際上,您可能做錯了什么:運行代碼時, indexes的尺寸為(7,3)而不是(7,2) 3對應於3維,而7對應於img非零元素的數量。

sess.run(indexes)完整結果:

array([[0, 0, 0],
      [0, 1, 2],
      [0, 2, 1],
      [1, 0, 0],
      [1, 0, 1],
      [1, 0, 2],
      [1, 1, 2]])

您可以使用tf.math.top_k()根據其值從輸入中批量獲取值和索引,然后計算掩碼並將其應用於值和索引。

img = tf.constant([[[1., 0., 0.], 
                    [0., 0., 2.],
                    [0., 3, 0.]], 
                   [[1., 2., 3.], 
                    [0., 0., 1.], 
                    [0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

values, indices = tf.math.top_k(img, k=3)
# values would be 
# [[[1., 0., 0.], [2., 0., 0.], [3., 0., 0.]],
#  [[3., 2., 1.], [1., 0., 0.], [0., 0., 0.]]]
# indices would be
# [[[0, 1, 2], [2, 0, 1], [1, 0, 2]],
#  [[2, 1, 0], [2, 0, 1], [0, 1, 2]]]

mask = tf.cast(values, dtype=tf.bool)
# mask would be
# [[[True, False, False], [True, False, False], [True, False, False]],
#  [[True, True, True], [True, False, False], [False, False, False]]]

現在,您可以使用valuesmask獲得img的非零值,也可以使用indicesmask獲得img的非零索引。 您可以使用tf.gather()imgindices獲取值,如下所示:

values2 = tf.gather(img, indices, batch_dims=2)
# values2 will be same with the above values

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM