[英]How to use tf.gather in combination with tf.where
The following is not working due to the shape of tf.where()
.由于tf.where()
的形状,以下内容不起作用。 Is there a nice way to fix this?有解决这个问题的好方法吗?
I want the values of tensor_y
where tensor_x
fulfills a condition (eg == value ).我想要tensor_y
的值,其中tensor_x
满足条件(例如 == value )。 Important, the tensors have batch_dims = 1
.重要的是,张量有batch_dims = 1
。
tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
trues = tf.where(tensor_y ==1)
new_tensor = tf.gather(tensor_y, axis=-1, indices = trues,batch_dims=1)
What I am doing now works*, but it is not so efficient I think:我现在正在做的工作*,但我认为效率不高:
new_tensor = tf.stack([tf.gather(tensor_y[i,:], tf.where(tensor_x[i,:] == 1)) for i in range(tensor_x.shape[0])])
*sometimes (I don't know under which conditions) I get his error: *有时(我不知道在什么情况下)我得到他的错误:
Shapes of all inputs must match: values[0].shape = [3,1].= values[1],shape = [6:1] [Op:Pack] name: stack
所有输入的形状必须匹配: values[0].shape = [3,1].= values[1],shape = [6:1] [Op:Pack] name: stack
Is this what you need?这是你需要的吗?
tensor_x = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
tensor_y = tf.random.uniform(shape=[2, 10], minval=-1, maxval=2, dtype=tf.int32)
new_tensor = tensor_y[(tensor_x==1)]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.