簡體   English   中英

Tensorflow:在 tensorflow 中創建等效的 torch.gather()

[英]Tensorflow: Create the torch.gather() equivalent in tensorflow

我想在 TensorFlow 2.X 中復制torch.gather() function。 我有一個張量A (形狀:[2,4,3 [2, 4, 3] )和一個相應的索引張量I (形狀: [2,2,3] )。 使用torch.gather()產生以下結果:

A = torch.tensor([[[10,20,30], [100,200,300], [1000,2000,3000]],
                  [[50,60,70], [500,600,700], [5000,6000,7000]]])
I = torch.tensor([[[0,1,0], [1,2,1]],
                  [[2,1,2], [1,0,1]]])
torch.gather(A, 1, I)

>
tensor([[[10,   200,   30], [100, 2000, 300]],
         [5000, 600, 7000], [500,   60, 700]]])

我試過使用tf.gather() ,但這並沒有產生類似 pytorch 的結果。 我也試過玩tf.gather_nd() ,但我找不到合適的解決方案。

我找到了這篇StackOverflow 帖子,但這似乎對我不起作用。

編輯:使用tf.gather_nd(A, I)時,我得到以下結果:

tf.gather_nd(A, I)

>
[[100, 6000],
 [  0,   60]]

tf.gather(A, I)的結果相當長。 它的形狀為[2, 2, 3, 4, 3]

torch.gathertf.gather_nd工作方式不同,因此在使用相同的索引張量時會產生不同的結果(在某些情況下也會返回錯誤)。 這是指數張量必須看起來像才能獲得相同的結果:

import tensorflow as tf

A = tf.constant([[
                   [10,20,30], [100,200,300], [1000,2000,3000]],
                  [[50,60,70], [500,600,700], [5000,6000,7000]]])
I = tf.constant([[[
                  [0,0,0],
                  [0,1,1], 
                  [0,0,2],
                ],[
                  [0,1,0],
                  [0,2,1],
                  [0,1,2],  
                ]], 
                 [[
                  [1,2,0],
                  [1,1,1],
                  [1,2,2],  
                ], 
                  [
                  [1,1,0],
                  [1,0,1],
                  [1,1,2],  
                ]]])


print(tf.gather_nd(A, I))
tf.Tensor(
[[[  10  200   30]
  [ 100 2000  300]]

 [[5000  600 7000]
  [ 500   60  700]]], shape=(2, 2, 3), dtype=int32)

所以,問題實際上是你如何計算你的指數,或者它們總是硬編碼的? 另外,請查看這篇關於這兩種操作的區別的帖子

至於您鏈接的帖子對您不起作用,您只需要輸入索引,一切都應該沒問題:

def torch_gather(x, indices, gather_axis):

    all_indices = tf.where(tf.fill(indices.shape, True))
    gather_locations = tf.reshape(indices, [indices.shape.num_elements()])

    gather_indices = []
    for axis in range(len(indices.shape)):
        if axis == gather_axis:
            gather_indices.append(tf.cast(gather_locations, dtype=tf.int64))
        else:
            gather_indices.append(tf.cast(all_indices[:, axis], dtype=tf.int64))

    gather_indices = tf.stack(gather_indices, axis=-1)
    gathered = tf.gather_nd(x, gather_indices)
    reshaped = tf.reshape(gathered, indices.shape)
    return reshaped

I = tf.constant([[[0,1,0], [1,2,1]],
                  [[2,1,2], [1,0,1]]])
A = tf.constant([[
                   [10,20,30], [100,200,300], [1000,2000,3000]],
                  [[50,60,70], [500,600,700], [5000,6000,7000]]])
print(torch_gather(A, I, 1))
tf.Tensor(
[[[  10  200   30]
  [ 100 2000  300]]

 [[5000  600 7000]
  [ 500   60  700]]], shape=(2, 2, 3), dtype=int32)

你也可以試試這個相當於 torch.gather:

import random
import numpy as np
import tensorflow as tf
import torch

# torch.gather equivalent
def tf_gather(x: tf.Tensor, indices: tf.Tensor, axis: int) -> tf.Tensor:
    complete_indices = np.array(np.where(indices > -1))
    complete_indices[axis] = tf.reshape(indices, [-1])
    flat_ind = np.ravel_multi_index(tuple(complete_indices), x.shape)
    return tf.reshape(tf.gather(tf.reshape(x, [-1]), flat_ind), indices.shape)


# ======= test program ========
if __name__ == '__main__':

    a = np.random.rand(2, 5, 3, 4)
    dim = 2  # 0 <= dim < len(a.shape))

    ind = np.expand_dims(np.argmax(a, axis=dim), axis=dim)

    # ========== np: groundtruth ==========
    np_max = np.expand_dims(np.max(a, axis=dim), axis=dim)

    # ========= torch: gather =========
    torch_max = torch.gather(torch.tensor(a), dim=dim, index=torch.tensor(ind))

    # ========= tensorflow: torch-like gather =========
    tf_max = tf_gather(tf.convert_to_tensor(a), axis=dim, indices=tf.convert_to_tensor(ind))

    keepdim = False
    if not keepdim:
        np_max = np.squeeze(np_max, axis=dim)
        torch_max = torch.squeeze(torch_max, dim=dim)
        tf_max = tf.squeeze(tf_max, axis=dim)

    # print('np_max:\n', np_max)
    # print('torch_max:\n', torch_max)
    # print('tf_max:\n', tf_max)

    assert np.allclose(np_max, torch_max.numpy()), '\33[1m\33[31mError with torch\33[0m'
    assert np.allclose(np_max, tf_max.numpy()), '\33[1m\33[31mError with tensorflow\33[0m'

    print('\33[1m\33[32mSuccess!\33[0m')

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM