[英]Tensorflow: Create the torch.gather() equivalent in tensorflow
我想在 TensorFlow 2.X 中復制torch.gather()
function。 我有一個張量A (形狀:[2,4,3 [2, 4, 3]
)和一個相應的索引張量I (形狀: [2,2,3]
)。 使用torch.gather()
產生以下結果:
A = torch.tensor([[[10,20,30], [100,200,300], [1000,2000,3000]],
[[50,60,70], [500,600,700], [5000,6000,7000]]])
I = torch.tensor([[[0,1,0], [1,2,1]],
[[2,1,2], [1,0,1]]])
torch.gather(A, 1, I)
>
tensor([[[10, 200, 30], [100, 2000, 300]],
[5000, 600, 7000], [500, 60, 700]]])
我試過使用tf.gather()
,但這並沒有產生類似 pytorch 的結果。 我也試過玩tf.gather_nd()
,但我找不到合適的解決方案。
我找到了這篇StackOverflow 帖子,但這似乎對我不起作用。
編輯:使用tf.gather_nd(A, I)
時,我得到以下結果:
tf.gather_nd(A, I)
>
[[100, 6000],
[ 0, 60]]
tf.gather(A, I)
的結果相當長。 它的形狀為[2, 2, 3, 4, 3]
torch.gather
和tf.gather_nd
工作方式不同,因此在使用相同的索引張量時會產生不同的結果(在某些情況下也會返回錯誤)。 這是指數張量必須看起來像才能獲得相同的結果:
import tensorflow as tf
A = tf.constant([[
[10,20,30], [100,200,300], [1000,2000,3000]],
[[50,60,70], [500,600,700], [5000,6000,7000]]])
I = tf.constant([[[
[0,0,0],
[0,1,1],
[0,0,2],
],[
[0,1,0],
[0,2,1],
[0,1,2],
]],
[[
[1,2,0],
[1,1,1],
[1,2,2],
],
[
[1,1,0],
[1,0,1],
[1,1,2],
]]])
print(tf.gather_nd(A, I))
tf.Tensor(
[[[ 10 200 30]
[ 100 2000 300]]
[[5000 600 7000]
[ 500 60 700]]], shape=(2, 2, 3), dtype=int32)
所以,問題實際上是你如何計算你的指數,或者它們總是硬編碼的? 另外,請查看這篇關於這兩種操作的區別的帖子。
至於您鏈接的帖子對您不起作用,您只需要輸入索引,一切都應該沒問題:
def torch_gather(x, indices, gather_axis):
all_indices = tf.where(tf.fill(indices.shape, True))
gather_locations = tf.reshape(indices, [indices.shape.num_elements()])
gather_indices = []
for axis in range(len(indices.shape)):
if axis == gather_axis:
gather_indices.append(tf.cast(gather_locations, dtype=tf.int64))
else:
gather_indices.append(tf.cast(all_indices[:, axis], dtype=tf.int64))
gather_indices = tf.stack(gather_indices, axis=-1)
gathered = tf.gather_nd(x, gather_indices)
reshaped = tf.reshape(gathered, indices.shape)
return reshaped
I = tf.constant([[[0,1,0], [1,2,1]],
[[2,1,2], [1,0,1]]])
A = tf.constant([[
[10,20,30], [100,200,300], [1000,2000,3000]],
[[50,60,70], [500,600,700], [5000,6000,7000]]])
print(torch_gather(A, I, 1))
tf.Tensor(
[[[ 10 200 30]
[ 100 2000 300]]
[[5000 600 7000]
[ 500 60 700]]], shape=(2, 2, 3), dtype=int32)
你也可以試試這個相當於 torch.gather:
import random
import numpy as np
import tensorflow as tf
import torch
# torch.gather equivalent
def tf_gather(x: tf.Tensor, indices: tf.Tensor, axis: int) -> tf.Tensor:
complete_indices = np.array(np.where(indices > -1))
complete_indices[axis] = tf.reshape(indices, [-1])
flat_ind = np.ravel_multi_index(tuple(complete_indices), x.shape)
return tf.reshape(tf.gather(tf.reshape(x, [-1]), flat_ind), indices.shape)
# ======= test program ========
if __name__ == '__main__':
a = np.random.rand(2, 5, 3, 4)
dim = 2 # 0 <= dim < len(a.shape))
ind = np.expand_dims(np.argmax(a, axis=dim), axis=dim)
# ========== np: groundtruth ==========
np_max = np.expand_dims(np.max(a, axis=dim), axis=dim)
# ========= torch: gather =========
torch_max = torch.gather(torch.tensor(a), dim=dim, index=torch.tensor(ind))
# ========= tensorflow: torch-like gather =========
tf_max = tf_gather(tf.convert_to_tensor(a), axis=dim, indices=tf.convert_to_tensor(ind))
keepdim = False
if not keepdim:
np_max = np.squeeze(np_max, axis=dim)
torch_max = torch.squeeze(torch_max, dim=dim)
tf_max = tf.squeeze(tf_max, axis=dim)
# print('np_max:\n', np_max)
# print('torch_max:\n', torch_max)
# print('tf_max:\n', tf_max)
assert np.allclose(np_max, torch_max.numpy()), '\33[1m\33[31mError with torch\33[0m'
assert np.allclose(np_max, tf_max.numpy()), '\33[1m\33[31mError with tensorflow\33[0m'
print('\33[1m\33[32mSuccess!\33[0m')
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.