簡體   English   中英

在 tensorflow 中,如何在 y 上使用 tf.where() 到來自 x 的 select 行,其中 x 和 y 具有不同的形狀?

[英]In tensorflow, how to use tf.where() on y to select rows from x where x and y have different shapes?

我有一個大小為 x = 32x512 的矩陣和一個大小為 y = 32x1 的 label 向量。 標簽在 0-3 范圍內。 我需要來自 x 的 select 行,在相應的 y 中有 label 1。

我嘗試使用以下命令: temp_maps = tf.where( tf.equal(y,1), x, tf.math.scalar_mul(0,x) ) 但這給了我這個錯誤:

ValueError:尺寸必須相等,但對於 op:SelectV2 輸入形狀為 32 和 512:[32]、[32,512]、[32,512]

我想要的是 x 中標簽為 1 的行。 我正在使用 tf.math.scalar_mul(0,x) 因為在條件為 false 的情況下必須選擇某些東西,所以我選擇了一個零張量。

創建虛擬矩陣:

        import tensorflow as tf
        tf.enable_eager_execution()
        import numpy as np
        B = tf.convert_to_tensor(np.random.randint(0, 3, 32).reshape((32, 1)))
        A = tf.convert_to_tensor(np.arange(32*512).reshape((32, 512)))

使用tf.equal獲取所有 2-labels 的 boolean 張量:

    eq = tf.equal(B, 2)
    In [16]: print(eq)
    tf.Tensor(
    [[ True]
     [False]
     [False]
     [False]
     [False]
     [False]
     [False]
     [False]
     [False]
     [ True]
     [ True]
     [False]
     [ True]
     [False]
     [False]
     [False]
     [False]
     [False]
     [ True]
     [False]
     [False]
     [False]
     [False]
     [False]
     [ True]
     [False]
     [ True]
     [False]
     [False]
     [ True]
     [False]
     [False]], shape=(32, 1), dtype=bool)

現在您可以使用tf.where來獲取位置索引:

In [19]: tf.where(eq)
Out[19]: 
<tf.Tensor: id=45, shape=(8, 2), dtype=int64, numpy=
array([[ 0,  0],
       [ 9,  0],
       [10,  0],
       [12,  0],
       [18,  0],
       [24,  0],
       [26,  0],
       [29,  0]])>

如果您想獲得 A 的一部分,可以使用tf.gather

In [30]: tf.gather(A, tf.where(tf.equal(B, 2))[:, 0])
Out[30]: 
<tf.Tensor: id=105, shape=(8, 512), dtype=int64, numpy=
array([[    0,     1,     2, ...,   509,   510,   511],
       [ 4608,  4609,  4610, ...,  5117,  5118,  5119],
       [ 5120,  5121,  5122, ...,  5629,  5630,  5631],
       ...,
       [12288, 12289, 12290, ..., 12797, 12798, 12799],
       [13312, 13313, 13314, ..., 13821, 13822, 13823],
       [14848, 14849, 14850, ..., 15357, 15358, 15359]])>

我認為接受的答案無效,因為這是 a)memory 效率低下 b)圖形執行不兼容的方法。

由於我一直在努力解決同樣的問題,我必須找到一種方法來克服這個問題:

tf.where( tf.expand_dims(tf.equal(y,1), 1) , x , tf.math.scalar_mul(0,x) )

因此第一個布爾張量的形狀為 [32, 1],可廣播到 [32,512]。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM