簡體   English   中英

如何結合使用 tf.gather 和 tf.where

[英]How to use tf.gather in combination with tf.where

由於tf.where()的形狀,以下內容不起作用。 有解決這個問題的好方法嗎?

我想要tensor_y的值,其中tensor_x滿足條件(例如 == value )。 重要的是,張量有batch_dims = 1

tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
trues = tf.where(tensor_y ==1)


new_tensor = tf.gather(tensor_y, axis=-1, indices = trues,batch_dims=1)


我現在正在做的工作*,但我認為效率不高:

new_tensor = tf.stack([tf.gather(tensor_y[i,:], tf.where(tensor_x[i,:] == 1)) for i in range(tensor_x.shape[0])])

*有時(我不知道在什么情況下)我得到他的錯誤:

所有輸入的形狀必須匹配: values[0].shape = [3,1].= values[1],shape = [6:1] [Op:Pack] name: stack

這是你需要的嗎?

tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)

new_tensor = tensor_y[(tensor_x==1)]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM