繁体   English   中英

在 tensorflow 中,如何在 y 上使用 tf.where() 到来自 x 的 select 行,其中 x 和 y 具有不同的形状?

[英]In tensorflow, how to use tf.where() on y to select rows from x where x and y have different shapes?

我有一个大小为 x = 32x512 的矩阵和一个大小为 y = 32x1 的 label 向量。 标签在 0-3 范围内。 我需要来自 x 的 select 行,在相应的 y 中有 label 1。

我尝试使用以下命令: temp_maps = tf.where( tf.equal(y,1), x, tf.math.scalar_mul(0,x) ) 但这给了我这个错误:

ValueError:尺寸必须相等,但对于 op:SelectV2 输入形状为 32 和 512:[32]、[32,512]、[32,512]

我想要的是 x 中标签为 1 的行。 我正在使用 tf.math.scalar_mul(0,x) 因为在条件为 false 的情况下必须选择某些东西,所以我选择了一个零张量。

创建虚拟矩阵:

        import tensorflow as tf
        tf.enable_eager_execution()
        import numpy as np
        B = tf.convert_to_tensor(np.random.randint(0, 3, 32).reshape((32, 1)))
        A = tf.convert_to_tensor(np.arange(32*512).reshape((32, 512)))

使用tf.equal获取所有 2-labels 的 boolean 张量:

    eq = tf.equal(B, 2)
    In [16]: print(eq)
    tf.Tensor(
    [[ True]
     [False]
     [False]
     [False]
     [False]
     [False]
     [False]
     [False]
     [False]
     [ True]
     [ True]
     [False]
     [ True]
     [False]
     [False]
     [False]
     [False]
     [False]
     [ True]
     [False]
     [False]
     [False]
     [False]
     [False]
     [ True]
     [False]
     [ True]
     [False]
     [False]
     [ True]
     [False]
     [False]], shape=(32, 1), dtype=bool)

现在您可以使用tf.where来获取位置索引:

In [19]: tf.where(eq)
Out[19]: 
<tf.Tensor: id=45, shape=(8, 2), dtype=int64, numpy=
array([[ 0,  0],
       [ 9,  0],
       [10,  0],
       [12,  0],
       [18,  0],
       [24,  0],
       [26,  0],
       [29,  0]])>

如果您想获得 A 的一部分,可以使用tf.gather

In [30]: tf.gather(A, tf.where(tf.equal(B, 2))[:, 0])
Out[30]: 
<tf.Tensor: id=105, shape=(8, 512), dtype=int64, numpy=
array([[    0,     1,     2, ...,   509,   510,   511],
       [ 4608,  4609,  4610, ...,  5117,  5118,  5119],
       [ 5120,  5121,  5122, ...,  5629,  5630,  5631],
       ...,
       [12288, 12289, 12290, ..., 12797, 12798, 12799],
       [13312, 13313, 13314, ..., 13821, 13822, 13823],
       [14848, 14849, 14850, ..., 15357, 15358, 15359]])>

我认为接受的答案无效,因为这是 a)memory 效率低下 b)图形执行不兼容的方法。

由于我一直在努力解决同样的问题,我必须找到一种方法来克服这个问题:

tf.where( tf.expand_dims(tf.equal(y,1), 1) , x , tf.math.scalar_mul(0,x) )

因此第一个布尔张量的形状为 [32, 1],可广播到 [32,512]。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM