[英]Tensorflow - How to perform tf.gather with batch dimension
Unfortunately I dont know how to formulate the title of this questions, maybe someone can change it?不幸的是我不知道如何制定这个问题的标题,也许有人可以改变它?
How to replace the following for loop in an elegant way?如何以优雅的方式替换以下 for 循环?
#tensor.shape -> (batchsize,100)
#indices.shape -> (batchsize,100)
liste = []
for i in range(tensor.shape[0]):
liste.append(tf.gather(tensor[i,:], indices[i,:10]))
new_tensor = tf.stack(liste)
This should do the trick:这应该可以解决问题:
new_tensor = tf.gather(tensor, axis=-1, indices=indices[:, :10], batch_dims=1)
Here with a minimal reproducible example:这里有一个最小的可重现示例:
import tensorflow as tf
# for version 1.x
#tf.enable_eager_execution()
tensor = tf.random.normal((2, 10))
indices = tf.random.uniform(shape=[2, 10], minval=0, maxval=4, dtype=tf.int32)
liste = []
for i in range(tensor.shape[0]):
liste.append(tf.gather(tensor[i,:], indices[i,:5]))
new_tensor = tf.stack(liste)
print('tensor: ')
print(tensor)
print('new_tensor: ')
print(new_tensor)
new_tensor_v2 = tf.gather(tensor, axis=-1, indices=indices[:, :5], batch_dims=1)
print('new_tensor_v2: ')
print(new_tensor_v2)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.