Tensorflow: How to index one tensor using another tensor with unknown dimensions?
I have a tensor params with shape (?, 70, 64) and another tensor indices with shape (?, 1). I want to index into the first tensor's axis 1 using the second tensor, to get a result with shape (?, 64).
I can't figure out how to go about it. Here's what I've tried:
tf.gather(params, indices) # returns a tensor of shape (?, 1, 70, 64)
tf.gather(params, indices, axis=1) # returns a tensor of shape (?, ?, 1, 64)
tf.gather_nd(params, indices) # returns a tensor of shape (?, 70, 64)
(I have an older version of TensorFlow, which doesn't have batch_gather.) Any help would be appreciated.
Thanks!
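In NumPy terms, the desired behavior would look like the following sketch (the concrete values are hypothetical stand-ins for the unknown batch dimension):

```python
import numpy as np

# Hypothetical concrete values standing in for the unknown batch size.
params = np.arange(3 * 70 * 64).reshape(3, 70, 64)   # shape (3, 70, 64)
indices = np.array([[2], [1], [0]])                  # shape (3, 1)

# For each batch element i, select row indices[i, 0] along axis 1.
result = params[np.arange(params.shape[0]), indices[:, 0]]
print(result.shape)  # (3, 64)
```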
You can use tf.stack to convert your indices to a tensor of shape (?, 2), with the first number in the second dimension being the batch number. Then using these new indices with tf.gather_nd should give you what you want, if I understand your goal correctly.
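To make the index construction concrete, here is a NumPy sketch of the stacked (?, 2) indices (the index values are hypothetical):

```python
import numpy as np

init_indices = np.array([[2], [1], [0]])  # shape (3, 1), hypothetical values

# Pair each row's batch number with its position index -> shape (3, 2),
# so row i reads (i, init_indices[i, 0]).
batch_ids = np.arange(init_indices.shape[0])
full_indices = np.stack([batch_ids, init_indices[:, 0]], axis=1)
print(full_indices)  # [[0 2]
                     #  [1 1]
                     #  [2 0]]
```

Each row of full_indices is a complete (batch, position) coordinate into the first two axes of params, which is exactly the form tf.gather_nd expects.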
Since your indices is a tensor of shape (?, 1), batch_gather would give you (?, 1, 64), i.e. one reshape step away from your expected result tensor of shape (?, 64). The following code shows that the two methods give the same result:
import numpy as np
import tensorflow as tf

params = tf.constant(np.arange(3 * 70 * 64).reshape(3, 70, 64))
init_indices = tf.constant([[2], [1], [0]])
# Prepend each row's batch number so gather_nd picks (batch, position)
indices = tf.stack(
    [tf.range(init_indices.shape[0]), tf.reshape(init_indices, [-1])],
    axis=1
)
output = tf.gather_nd(params, indices)
batch_gather = tf.reshape(tf.batch_gather(params, init_indices),
                          [params.shape[0], -1])

with tf.Session() as sess:
    print('tf.gather_nd')
    print(output.shape)
    print(sess.run(output))
    print('batch_gather')
    print(batch_gather.shape)
    print(sess.run(batch_gather))
Overall, the optimal solution depends on the specific use case. To use tf.gather_nd with tf.stack, the key is to get the batch size, i.e. the first dimension. One way, which again may not be optimal, is to use tf.shape:
import numpy as np
import tensorflow as tf

params = tf.placeholder(shape=(None, 70, 64), dtype=tf.int32)
init_indices = tf.placeholder(shape=(None, 1), dtype=tf.int32)
# tf.shape returns the dynamic batch size even when the static shape is None
indices = tf.stack(
    [tf.range(tf.shape(init_indices)[0]), tf.reshape(init_indices, [-1])],
    axis=1
)
output = tf.gather_nd(params, indices)
batch_gather = tf.reshape(tf.batch_gather(params, init_indices),
                          [tf.shape(params)[0], -1])

with tf.Session() as sess:
    print('tf.gather_nd')
    print(output.shape)
    print(sess.run(
        output, feed_dict={params: np.arange(3*70*64).reshape(3, 70, 64),
                           init_indices: [[2], [1], [0]]}
    ))
    print('batch_gather')
    print(batch_gather.shape)
    print(sess.run(
        batch_gather, feed_dict={params: np.arange(3*70*64).reshape(3, 70, 64),
                                 init_indices: [[2], [1], [0]]}
    ))
One thing to point out: because the batch size is unknown, print(batch_gather.shape) gives (?, ?) rather than (?, 64).