簡體   English   中英

tensorflow 相當於torch.gather

[英]tensorflow equivalent of torch.gather

我有一個形狀為(16, 4096, 3)的張量。 我有另一個形狀指數張量(16, 32768, 3) 我正在嘗試沿dim=1收集值。 這最初是在 pytorch 中使用 Gather函數完成的,如下所示 -

# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)

請注意,輸出b的大小與idx的大小相同。 但是,當我應用 tensorflow 的gather功能時,我得到了完全不同的輸出。 發現輸出維度不匹配,如下所示-

b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)

我也嘗試過使用tf.gather_nd但徒勞無功。 見下文-

b = tf.gather_nd(a, idx)
# b.shape (16, 32768)

為什么我會得到不同形狀的張量? 我想獲得與pytorch計算出的形狀相同的張量。

換句話說,我想知道 tensorflow 相當於 torch.gather。

對於二維情況,有一種方法可以做到:

# a.shape (16L, 10L)
# idx.shape (16L,1)
idx = tf.stack([tf.range(tf.shape(idx)[0]),idx[:,0]],axis=-1)
b = tf.gather_nd(a,idx)

但是,對於 ND 情況,這種方法可能非常復雜

這“應該”是使用 tf.gather_nd 的通用解決方案(我只測試了沿最后一個軸的 2 級和 3 級張量):

def torch_gather(x, indices, gather_axis):
    # if pytorch gather indices are
    # [[[0, 10, 20], [0, 10, 20], [0, 10, 20]],
    #  [[0, 10, 20], [0, 10, 20], [0, 10, 20]]]
    # tf nd_gather needs to be
    # [[0,0,0], [0,0,10], [0,0,20], [0,1,0], [0,1,10], [0,1,20], [0,2,0], [0,2,10], [0,2,20],
    #  [1,0,0], [1,0,10], [1,0,20], [1,1,0], [1,1,10], [1,1,20], [1,2,0], [1,2,10], [1,2,20]]

    # create a tensor containing indices of each element
    all_indices = tf.where(tf.fill(indices.shape, True))
    gather_locations = tf.reshape(indices, [indices.shape.num_elements()])

    # splice in our pytorch style index at the correct axis
    gather_indices = []
    for axis in range(len(indices.shape)):
        if axis == gather_axis:
            gather_indices.append(gather_locations)
        else:
            gather_indices.append(all_indices[:, axis])

    gather_indices = tf.stack(gather_indices, axis=-1)
    gathered = tf.gather_nd(x, gather_indices)
    reshaped = tf.reshape(gathered, indices.shape)
    return reshaped

對於最后一個軸的收集,我們可以使用一般 ND 情況下的 2D-reshape 技巧,然后使用上面的@LiShaoyuan 2D 代碼

        # last-axis gathering only - use 2D-reshape-trick for Torch's style nD gathering
        def torch_gather(param, id_tensor):

            # 2d-gather torch equivalent from @LiShaoyuan above 
            def gather2d(target, id_tensor):
                idx = tf.stack([tf.range(tf.shape(id_tensor)[0]),id_tensor[:,0]],axis=-1)
                result = tf.gather_nd(target,idx)
                return tf.expand_dims(result,axis=-1)

            target = tf.reshape(param, (-1, param.shape[-1])) # reshape 2D
            target_shape = id_tensor.shape

            id_tensor = tf.reshape(id_tensor, (-1, 1)) # also 2D-index
            result = gather2d(target, id_tensor)
            return tf.reshape(result, target_shape)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM