[英]How to use tf.gather in combination with tf.where
由於tf.where()
的形狀,以下內容不起作用。 有解決這個問題的好方法嗎?
我想要tensor_y
的值,其中tensor_x
滿足條件(例如 == value )。 重要的是,張量有batch_dims = 1
。
tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
trues = tf.where(tensor_y ==1)
new_tensor = tf.gather(tensor_y, axis=-1, indices = trues,batch_dims=1)
我現在正在做的工作*,但我認為效率不高:
new_tensor = tf.stack([tf.gather(tensor_y[i,:], tf.where(tensor_x[i,:] == 1)) for i in range(tensor_x.shape[0])])
*有時(我不知道在什么情況下)我得到他的錯誤:
所有輸入的形狀必須匹配: values[0].shape = [3,1].= values[1],shape = [6:1] [Op:Pack] name: stack
這是你需要的嗎?
tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
new_tensor = tensor_y[(tensor_x==1)]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.