[英]Tensorflow: how to keep batch dimension when using tf.where()?
我正在嘗試 select 元素不同於零,然后再使用它們。 我的輸入張量具有批次維度,因此我想保留它並且不要將數據混合到批次中。 我認為tf.gather_nd()
對我有用,但首先我必須獲取所需數據的索引,然后我找到tf.where()
。 我嘗試了以下方法:
img = tf.constant([[[1., 0., 0.],
[0., 0., 2.],
[0., 3, 0.]],
[[1., 2., 3.],
[0., 0., 1.],
[0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]
indexes = tf.where(tf.not_equal(img, 0.))
我希望indexes
保持批量維度,但是它的形狀為[7, 2]
。 我懷疑問題出在不同批次中滿足條件的點數不同。
有沒有辦法讓索引保持批量維度? 提前致謝。
編輯: indexes
的形狀為[7, 3]
,其中第一個暗淡指的是點數,第二個暗淡指的是該點的 position(包括它屬於哪個批次)。 但是我需要indexes
具有特定的批處理維度,因為稍后我想用它來收集來自img
的數據:
Y = tf.gather_nd(img, indexes)
我希望Y
具有批次維度,但由於indexes
沒有,我得到一個平坦的張量,其中混合了來自不同批次的數據。
實際上,您可能做錯了什么:運行代碼時, indexes
的尺寸為(7,3)
而不是(7,2)
。 3
對應於3維,而7
對應於img
非零元素的數量。
sess.run(indexes)
完整結果:
array([[0, 0, 0],
[0, 1, 2],
[0, 2, 1],
[1, 0, 0],
[1, 0, 1],
[1, 0, 2],
[1, 1, 2]])
您可以使用tf.math.top_k()
根據其值從輸入中批量獲取值和索引,然后計算掩碼並將其應用於值和索引。
img = tf.constant([[[1., 0., 0.],
[0., 0., 2.],
[0., 3, 0.]],
[[1., 2., 3.],
[0., 0., 1.],
[0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]
values, indices = tf.math.top_k(img, k=3)
# values would be
# [[[1., 0., 0.], [2., 0., 0.], [3., 0., 0.]],
# [[3., 2., 1.], [1., 0., 0.], [0., 0., 0.]]]
# indices would be
# [[[0, 1, 2], [2, 0, 1], [1, 0, 2]],
# [[2, 1, 0], [2, 0, 1], [0, 1, 2]]]
mask = tf.cast(values, dtype=tf.bool)
# mask would be
# [[[True, False, False], [True, False, False], [True, False, False]],
# [[True, True, True], [True, False, False], [False, False, False]]]
現在,您可以使用values
和mask
獲得img
的非零值,也可以使用indices
和mask
獲得img
的非零索引。 您可以使用tf.gather()
從img
和indices
獲取值,如下所示:
values2 = tf.gather(img, indices, batch_dims=2)
# values2 will be same with the above values
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.