簡體   English   中英

Tensorflow - 如何使用批量維度執行 tf.gather

[英]Tensorflow - How to perform tf.gather with batch dimension

不幸的是我不知道如何制定這個問題的標題,也許有人可以改變它?

如何以優雅的方式替換以下 for 循環?

#tensor.shape -> (batchsize,100)
#indices.shape -> (batchsize,100) 
liste = []
for i in range(tensor.shape[0]):
        liste.append(tf.gather(tensor[i,:], indices[i,:10]))

new_tensor = tf.stack(liste)
        

這應該可以解決問題:

new_tensor = tf.gather(tensor, axis=-1, indices=indices[:, :10], batch_dims=1)

這里有一個最小的可重現示例:

import tensorflow as tf

# for version 1.x
#tf.enable_eager_execution()

tensor = tf.random.normal((2, 10))
indices = tf.random.uniform(shape=[2, 10], minval=0, maxval=4, dtype=tf.int32)

liste = []
for i in range(tensor.shape[0]):
        liste.append(tf.gather(tensor[i,:], indices[i,:5]))

new_tensor = tf.stack(liste)

print('tensor: ')
print(tensor)

print('new_tensor: ')
print(new_tensor)

new_tensor_v2 = tf.gather(tensor, axis=-1, indices=indices[:, :5], batch_dims=1)
print('new_tensor_v2: ')
print(new_tensor_v2)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM