[英]In tensorflow, how to use tf.where() on y to select rows from x where x and y have different shapes?
我有一個大小為 x = 32x512 的矩陣和一個大小為 y = 32x1 的 label 向量。 標簽在 0-3 范圍內。 我需要來自 x 的 select 行,在相應的 y 中有 label 1。
我嘗試使用以下命令: temp_maps = tf.where( tf.equal(y,1), x, tf.math.scalar_mul(0,x) ) 但這給了我這個錯誤:
ValueError:尺寸必須相等,但對於 op:SelectV2 輸入形狀為 32 和 512:[32]、[32,512]、[32,512]
我想要的是 x 中標簽為 1 的行。 我正在使用 tf.math.scalar_mul(0,x) 因為在條件為 false 的情況下必須選擇某些東西,所以我選擇了一個零張量。
創建虛擬矩陣:
import tensorflow as tf
tf.enable_eager_execution()
import numpy as np
B = tf.convert_to_tensor(np.random.randint(0, 3, 32).reshape((32, 1)))
A = tf.convert_to_tensor(np.arange(32*512).reshape((32, 512)))
使用tf.equal
獲取所有 2-labels 的 boolean 張量:
eq = tf.equal(B, 2)
In [16]: print(eq)
tf.Tensor(
[[ True]
[False]
[False]
[False]
[False]
[False]
[False]
[False]
[False]
[ True]
[ True]
[False]
[ True]
[False]
[False]
[False]
[False]
[False]
[ True]
[False]
[False]
[False]
[False]
[False]
[ True]
[False]
[ True]
[False]
[False]
[ True]
[False]
[False]], shape=(32, 1), dtype=bool)
現在您可以使用tf.where
來獲取位置索引:
In [19]: tf.where(eq)
Out[19]:
<tf.Tensor: id=45, shape=(8, 2), dtype=int64, numpy=
array([[ 0, 0],
[ 9, 0],
[10, 0],
[12, 0],
[18, 0],
[24, 0],
[26, 0],
[29, 0]])>
如果您想獲得 A 的一部分,可以使用tf.gather
:
In [30]: tf.gather(A, tf.where(tf.equal(B, 2))[:, 0])
Out[30]:
<tf.Tensor: id=105, shape=(8, 512), dtype=int64, numpy=
array([[ 0, 1, 2, ..., 509, 510, 511],
[ 4608, 4609, 4610, ..., 5117, 5118, 5119],
[ 5120, 5121, 5122, ..., 5629, 5630, 5631],
...,
[12288, 12289, 12290, ..., 12797, 12798, 12799],
[13312, 13313, 13314, ..., 13821, 13822, 13823],
[14848, 14849, 14850, ..., 15357, 15358, 15359]])>
我認為接受的答案無效,因為這是 a)memory 效率低下 b)圖形執行不兼容的方法。
由於我一直在努力解決同樣的問題,我必須找到一種方法來克服這個問題:
tf.where( tf.expand_dims(tf.equal(y,1), 1) , x , tf.math.scalar_mul(0,x) )
因此第一個布爾張量的形狀為 [32, 1],可廣播到 [32,512]。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.