简体   繁体   English

如何结合使用 tf.gather 和 tf.where

[英]How to use tf.gather in combination with tf.where

The following is not working due to the shape of tf.where() .由于tf.where()的形状,以下内容不起作用。 Is there a nice way to fix this?有解决这个问题的好方法吗?

I want the values of tensor_y where tensor_x fulfills a condition (eg == value ).我想要tensor_y的值,其中tensor_x满足条件(例如 == value )。 Important, the tensors have batch_dims = 1 .重要的是,张量有batch_dims = 1

tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
trues = tf.where(tensor_y ==1)


new_tensor = tf.gather(tensor_y, axis=-1, indices = trues,batch_dims=1)


What I am doing now works*, but it is not so efficient I think:我现在正在做的工作*,但我认为效率不高:

new_tensor = tf.stack([tf.gather(tensor_y[i,:], tf.where(tensor_x[i,:] == 1)) for i in range(tensor_x.shape[0])])

*sometimes (I don't know under which conditions) I get his error: *有时(我不知道在什么情况下)我得到他的错误:

Shapes of all inputs must match: values[0].shape = [3,1].= values[1],shape = [6:1] [Op:Pack] name: stack所有输入的形状必须匹配: values[0].shape = [3,1].= values[1],shape = [6:1] [Op:Pack] name: stack

Is this what you need?这是你需要的吗?

tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)

new_tensor = tensor_y[(tensor_x==1)]

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM