I am looking for the TensorFlow equivalent of the following NumPy code. The arrays a and idx_2 are given; the goal is to construct b.
import numpy as np

# A float Tensor obtained somehow (integer values here for readability)
a = np.arange(3*5).reshape(3,5)
# An int Tensor obtained somehow
idx_2 = np.array([[1,2,3,4],[0,2,3,4],[0,2,3,4]])
# An int Tensor, constructed for indexing
idx_1 = np.arange(a.shape[0]).reshape(-1,1)
# The goal
b = a[idx_1, idx_2]
print(b)
[[ 1  2  3  4]
 [ 5  7  8  9]
 [10 12 13 14]]
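The NumPy version works because idx_1 has shape (3, 1) and idx_2 has shape (3, 4): the two index arrays broadcast to (3, 4), so b[i, j] = a[i, idx_2[i, j]]. A minimal check of that reading follows. It uses np.broadcast_arrays and np.take_along_axis, which are real NumPy functions but not part of the question; take_along_axis is shown only as an equivalent per-row gather that does not need idx_1.

```python
import numpy as np

a = np.arange(3 * 5).reshape(3, 5)
idx_2 = np.array([[1, 2, 3, 4], [0, 2, 3, 4], [0, 2, 3, 4]])
idx_1 = np.arange(a.shape[0]).reshape(-1, 1)  # shape (3, 1)

# The two index arrays broadcast together to shape (3, 4):
# rows[i, j] == i and cols[i, j] == idx_2[i, j].
rows, cols = np.broadcast_arrays(idx_1, idx_2)

# Advanced indexing picks a[rows[i, j], cols[i, j]] for every (i, j).
b = a[idx_1, idx_2]

# take_along_axis gathers along axis 1 separately for each row,
# giving the same per-row selection without building idx_1 by hand.
b_alt = np.take_along_axis(a, idx_2, axis=1)

print(rows.shape, cols.shape)  # (3, 4) (3, 4)
print(b_alt)
```

The broadcast row index is exactly the piece that has to be made explicit when moving to tf.gather_nd.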
I have tried indexing the tensors directly and using tf.gather_nd, but I keep getting errors, so I decided to ask here. Everywhere I look for answers, people use tf.gather_nd (hence the title) to solve similar problems, but to apply this function I have to somehow reshape the indices so that they also select along the first dimension. How do I do this? Please help.
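tf.gather_nd reads the last axis of its indices argument as a coordinate into params, so for a 2-D a each innermost entry has to be a (row, col) pair. The reshaping asked about therefore amounts to pairing every entry of idx_2 with its row number. Below is a NumPy sketch of that index layout; gather_nd_2d is a hypothetical helper written only to mimic what tf.gather_nd does in this 2-D case, not a TensorFlow function.

```python
import numpy as np


def gather_nd_2d(params, indices):
    """Hypothetical stand-in for tf.gather_nd when params is 2-D and
    indices[..., 0] / indices[..., 1] hold row / column coordinates."""
    return params[indices[..., 0], indices[..., 1]]


a = np.arange(3 * 5).reshape(3, 5)
idx_2 = np.array([[1, 2, 3, 4], [0, 2, 3, 4], [0, 2, 3, 4]])

# Row number for every entry of idx_2, shape (3, 4).
rows = np.broadcast_to(np.arange(idx_2.shape[0])[:, None], idx_2.shape)

# Stack on a new last axis: gather_idx[i, j] == [i, idx_2[i, j]], shape (3, 4, 2).
gather_idx = np.stack([rows, idx_2], axis=-1)

b = gather_nd_2d(a, gather_idx)
print(gather_idx.shape)  # (3, 4, 2)
print(b)
```

Because the index array keeps the (3, 4) leading shape, the gathered result already has shape (3, 4), so no flatten-then-reshape step is needed.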
TensorFlow can be quite clumsy for things that are simple and Pythonic in NumPy. Here is how I used tf.gather_nd to reproduce your example in TensorFlow 1.x (graph mode with tf.Session). There is probably a much better way to do it, though.
import tensorflow as tf
import numpy as np

with tf.Session() as sess:
    # Define 'a'
    a = tf.reshape(tf.range(15), (3, 5))
    # Define both index tensors (.eval() turns them into NumPy arrays)
    idx_1 = tf.reshape(tf.range(a.get_shape().as_list()[0]), (-1, 1)).eval()
    idx_2 = tf.constant([[1, 2, 3, 4], [0, 2, 3, 4], [0, 2, 3, 4]]).eval()
    # Build the (row, col) index pairs for use with gather_nd
    gather_idx = tf.constant([(x[0], y) for (i, x) in enumerate(idx_1) for y in idx_2[i]])
    # Extract the elements and reshape to the desired dimensions
    b = tf.gather_nd(a, gather_idx)
    b = tf.reshape(b, (idx_1.shape[0], idx_2.shape[1]))
    print(sess.run(b))
[[ 1  2  3  4]
 [ 5  7  8  9]
 [10 12 13 14]]
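The Python list comprehension can be avoided by building the indices with tensor ops. This is a sketch assuming TensorFlow 2.x with eager execution, so there is no tf.Session or .eval(). It shows two options: tf.gather with its batch_dims argument, and tf.gather_nd with the (row, col) pairs built by tf.stack instead of a loop. Both should produce the same (3, 4) result as above.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

a = tf.reshape(tf.range(15), (3, 5))
idx_2 = tf.constant([[1, 2, 3, 4], [0, 2, 3, 4], [0, 2, 3, 4]])

# Option 1: batch_dims=1 makes tf.gather treat axis 0 as a batch axis,
# so b_gather[i, j] == a[i, idx_2[i, j]].
b_gather = tf.gather(a, idx_2, batch_dims=1)

# Option 2: build the (row, col) pairs for tf.gather_nd without a Python loop.
rows = tf.range(tf.shape(idx_2)[0])                           # [0, 1, 2]
rows = tf.broadcast_to(rows[:, tf.newaxis], tf.shape(idx_2))  # shape (3, 4)
gather_idx = tf.stack([rows, idx_2], axis=-1)                 # shape (3, 4, 2)
b_gather_nd = tf.gather_nd(a, gather_idx)                     # shape (3, 4), no reshape needed

print(b_gather.numpy())
print(b_gather_nd.numpy())
```

Option 1 is the shorter of the two; Option 2 stays closer to the tf.gather_nd approach above and generalizes to cases where the row coordinates are not simply 0, 1, 2, ...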