
Indexing k-th dimension of tensor with another tensor in TensorFlow 2.0

I have a tensor probs with shape (None, None, 110), representing (batch_size, sequence_length, 110) in an LSTM. I have another tensor indices with shape (None, None), which contains the indices of the elements to select from the third dimension of probs.

I want to use indices to index probs along that last dimension, so that the result has shape (batch_size, sequence_length).

Numpy equivalent:

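import numpy as np
# j holds batch positions and k holds sequence positions, both of shape (batch_size, sequence_length)
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]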

Since shape[0] and shape[1] of probs are not known, tf.meshgrid() is not an option. I found tf.gather, tf.gather_nd and tf.batch_gather, but none of them seems to do what I want.

Does anybody know how to do this?

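You can do that with tf.gather_nd by treating the first two dimensions as batch dimensions (batch_dims=2), so that each entry of indices selects one element along the last axis: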

indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)
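
To check that this works when the first two dimensions are unknown, here is a minimal sketch. The function name select_along_last_axis, the toy shapes and the use of 4 classes instead of 110 are illustrative assumptions, not part of the original question:

```python
import tensorflow as tf


# The input signature leaves batch_size and sequence_length as None,
# mirroring the (None, None, 110) shape from the question (4 classes here).
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None, 4], dtype=tf.float32),
    tf.TensorSpec(shape=[None, None], dtype=tf.int32),
])
def select_along_last_axis(probs, indices):
    # expand_dims turns indices into shape (batch, seq, 1); with batch_dims=2
    # each length-1 index vector picks one element from the matching row.
    return tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)


# Toy data: probs[b, s, c] == b * 12 + s * 4 + c
probs = tf.reshape(tf.range(24, dtype=tf.float32), (2, 3, 4))
indices = tf.constant([[0, 1, 2], [3, 0, 1]], dtype=tf.int32)

result = select_along_last_axis(probs, indices)
print(result.numpy())
```

The result has shape (2, 3), one selected value per (batch, step) position.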

By the way, in NumPy you can use np.take_along_axis to do the same:

indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]
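
As a quick sanity check, the sketch below compares the meshgrid version from the question with np.take_along_axis on random data. The array sizes and variable names are arbitrary assumptions chosen for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
probs = rng.random((2, 3, 5))               # (batch_size, sequence_length, classes)
indices = rng.integers(0, 5, size=(2, 3))   # one class index per (batch, step)

# Approach from the question: explicit index grids for the first two axes.
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
via_meshgrid = probs[j, k, indices]

# Approach from the answer: gather along the last axis, then drop that axis.
via_take = np.take_along_axis(
    probs, np.expand_dims(indices, axis=-1), axis=-1
)[..., 0]

same = np.array_equal(via_meshgrid, via_take)
print(same, via_take.shape)
```

Both approaches should give an array of shape (batch_size, sequence_length).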


 