[英]Filtering tensors element in Tensorflow
What's the equivalent operation in Tensorflow for this? Tensorflow 中的等效操作是什么? For example, I have a
x = np.array([-12,4,6,8,100])
.例如,我有一个
x = np.array([-12,4,6,8,100])
。 I want to do as simple as this: x = x[x>5]
, but I can't find any TF operation for this.我想做这样简单的事情:
x = x[x>5]
,但我找不到任何 TF 操作。 Thanks!谢谢!
In TF
you can do something like this to achieve similar results.在
TF
中,您可以执行类似的操作来获得类似的结果。
import numpy as np
import tensorflow as tf
x = np.array([-12,4,6,8,100])
y = tf.gather(x, tf.where(x > 5))
y.numpy().reshape(-1)
array([ 6, 8, 100])
Details细节
The tf.where
will return the indices of condition
that are True
. tf.where
将返回condition
索引为True
。 Such as如
x = np.array([-12,4,6,8,100])
tf.where(x > 5)
<tf.Tensor: shape=(3, 1), dtype=int64, numpy=
array([[2],
[3],
[4]])>
And next, using tf.gather
, it slices from params ( x
) axis according to indices (from tf.where
).接下来,使用
tf.gather
,它根据索引(来自tf.where
)从参数( x
)轴切片。 Such as如
tf.gather(x, tf.where(x > 5))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.