简体   繁体   English

Tensorflow:如何使用另一个未知维度的张量索引一个张量?

[英]Tensorflow: How to index one tensor using another tensor with unknown dimensions?

I have a tensor params with shape (?, 70, 64) and another tensor indices with shape (?, 1) .我有一个形状为(?, 70, 64)的张量params和另一个形状为(?, 1)张量indices I want to index into the first tensor's axis 1 using the second tensor, to get a result with shape (?, 64) .我想使用第二个张量索引第一个张量的轴 1,以获得形状为(?, 64)

I can't figure how to go about it.我不知道该怎么做。 Here's what I've tried:这是我尝试过的:

tf.gather(params, indices)           # returns a tensor of shape (?, 1, 70, 64)
tf.gather(params, indices, axis=1)   # returns a tensor of shape (?, ?, 1, 64)
tf.gather_nd(params, indices)        # returns a tensor of shape (?, 70, 64)

(I have an older version of TensorFlow, which doesn't have batch_gather . ) Any help would be appreciated. (我有一个旧版本的 TensorFlow,它没有batch_gather 。)任何帮助将不胜感激。

Thanks!谢谢!

You can use tf.stack to convert your indices to a tensor of shape (?, 2) with the first number in the second dimension being the batch number.您可以使用tf.stack将您的索引转换为形状为(?, 2)的张量,其中第二维中的第一个数字是批号。 Then using this new indices with tf.gather_nd should give you what you want if I understand your goal correctly.然后,如果我正确理解您的目标,将这个新索引与tf.gather_nd一起tf.gather_nd应该会给您想要的。

Since your indices is a tensor of shape (?, 1) , batch_gather would give you (?, 1, 64) , meaning one reshape step from your expected result tensor of shape (?, 64) .由于您的indices是形状(?, 1)的张量, batch_gather会给您(?, 1, 64) ,这意味着从形状(?, 64)预期结果张量的一个重塑步骤。 The following code shows two methods give you the same result:下面的代码显示了两种方法给你相同的结果:

import numpy as np
import tensorflow as tf

params = tf.constant(np.arange(3*70*64).reshape(3, 70, 64))
init_indices = tf.constant([[2], [1], [0]])
indices = tf.stack(
    [tf.range(init_indices.shape[0]), tf.reshape(init_indices, [-1])],
    axis=1
)
output = tf.gather_nd(params, indices)
batch_gather = tf.reshape(tf.batch_gather(params, init_indices),
                          [params.shape[0], -1])

with tf.Session() as sess:
    print('tf.gather_nd')
    print(output.shape)
    print(sess.run(output))
    print('batch_gather')
    print(batch_gather.shape)
    print(sess.run(batch_gather))

Edit on comment "first dimension unknown"编辑评论“第一维未知”

Overall, the optimal solution depends on the specific use case, and to use tf.gather_nd with tf.stack , the key is to get the batch size, ie the first dimension.总的来说,最佳解决方案取决于具体用例,并且将tf.gather_ndtf.stack一起使用,关键是获得批量大小,即第一维。 One way, which again may not be optimal, is to use tf.shape :一种可能不是最佳的方法是使用tf.shape

import numpy as np
import tensorflow as tf

params = tf.placeholder(shape=(None, 70, 64), dtype=tf.int32)
init_indices = tf.placeholder(shape=(None, 1), dtype=tf.int32)
indices = tf.stack(
    [tf.range(tf.shape(init_indices)[0]), tf.reshape(init_indices, [-1])],
    axis=1
)
output = tf.gather_nd(params, indices)
batch_gather = tf.reshape(tf.batch_gather(params, init_indices),
                          [tf.shape(params)[0], -1])

with tf.Session() as sess:
    print('tf.gather_nd')
    print(output.shape)
    print(sess.run(
        output, feed_dict={params: np.arange(3*70*64).reshape(3, 70, 64),
                           init_indices: [[2], [1], [0]]}
    ))
    print('batch_gather')
    print(batch_gather.shape)
    print(sess.run(
        batch_gather, feed_dict={params: np.arange(3*70*64).reshape(3, 70, 64),
                                 init_indices: [[2], [1], [0]]}
    ))

One thing to point out is because batch size is unknown, print(batch_gather.shape) gives (?, ?) rather than (?, 64) .需要指出的一件事是因为批量大小未知,所以print(batch_gather.shape)给出(?, ?)而不是(?, 64)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM