
Indexing tensor in tensorflow using a tensor of integers

My question is similar to the one here, but not exactly the same. I have two tensors:

mu: (shape=(1000,1), dtype=np.float32)
p:  (shape=(100,30), dtype=np.int64)

What I want is to create a new tensor

x: (shape=(100,30), dtype=np.float32)

such that

x[i,j] = mu[p[i,j]]

In NumPy this can be done with advanced indexing:

x = mu[p]
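A small NumPy check of the indexing above, using made-up sizes (10 values and a 3×2 index array instead of the question's 1000 and 100×30). Because mu has a trailing axis of length 1, mu[p] actually comes out with shape p.shape + (1,); also indexing the column, as in mu[p, 0], gives exactly p's shape:

```python
import numpy as np

# Hypothetical small inputs: mu has shape (10, 1), p holds row indices into mu.
mu = (np.arange(10, dtype=np.float32) * 0.1).reshape(10, 1)
p = np.array([[1, 3], [2, 4], [3, 1]], dtype=np.int64)

# Reference definition x[i, j] = mu[p[i, j], 0], written as an explicit loop.
x_loop = np.empty(p.shape, dtype=mu.dtype)
for i in range(p.shape[0]):
    for j in range(p.shape[1]):
        x_loop[i, j] = mu[p[i, j], 0]

print(mu[p].shape)  # (3, 2, 1): mu's trailing length-1 axis is kept
x = mu[p, 0]        # advanced index on axis 0, integer index on axis 1
print(x.shape)      # (3, 2)
assert np.array_equal(x, x_loop)
```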

I have tried using tf.gather_nd(mu, p), but in my case I get the following error:

ValueError: indices.shape[-1] must be <= params.rank, but saw indices shape: [100,30] and params shape: [1000] for 'GatherNd_2' (op: 'GatherNd') with input shapes: [1000], [100,30].

It therefore seems that, in order to use gather_nd, I have to build a new tensor of coordinates. Is there a simpler way to accomplish what I want?
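The error follows from the shape rule of tf.gather_nd: the last axis of indices holds a coordinate tuple into the leading axes of params, so indices.shape[-1] cannot exceed the rank of params, and the result has shape indices.shape[:-1] + params.shape[indices.shape[-1]:]. Below is a sketch that applies only that rule to the shapes involved here; gather_nd_output_shape is a hypothetical helper, not a TensorFlow function, and it ignores batch_dims:

```python
def gather_nd_output_shape(indices_shape, params_shape):
    """Hypothetical helper: output shape of a gather_nd-style lookup (batch_dims=0).

    The last axis of `indices` is a coordinate tuple into the leading axes of
    `params`; any remaining axes of `params` are returned whole.
    """
    depth = indices_shape[-1]
    if depth > len(params_shape):
        raise ValueError(
            f"indices.shape[-1] must be <= params.rank, but saw indices shape "
            f"{list(indices_shape)} and params shape {list(params_shape)}"
        )
    return tuple(indices_shape[:-1]) + tuple(params_shape[depth:])


# The failing call from the question: indices (100, 30), params of rank 1.
try:
    gather_nd_output_shape((100, 30), (1000,))
except ValueError as err:
    print(err)

# A trailing coordinate axis of length 1 makes each entry of p a 1-D coordinate.
print(gather_nd_output_shape((100, 30, 1), (1000,)))    # (100, 30)
# (row, 0) coordinates into the 2-D mu also give the desired shape.
print(gather_nd_output_shape((100, 30, 2), (1000, 1)))  # (100, 30)
```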

Here is a working solution:

tf.reshape(tf.gather(mu[:,0], tf.reshape(p, (-1,))), p.shape)

Basically, it:

  1. flattens the index array to 1-D with tf.reshape(p, (-1,));
  2. gathers the elements from mu[:,0] (the first column of mu);
  3. reshapes the result to p's shape.
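The same three steps can be mirrored in NumPy, with np.take standing in for tf.gather (both pick entries along axis 0 by index). This is a sketch on made-up data, mainly useful for checking shapes and values:

```python
import numpy as np

mu = (np.arange(10, dtype=np.float32) * 0.1).reshape(10, 1)  # shape (10, 1)
p = np.array([[1, 3], [2, 4], [3, 1]], dtype=np.int64)        # shape (3, 2)

flat_idx = p.reshape(-1)              # 1. flatten the indices: shape (6,)
picked = np.take(mu[:, 0], flat_idx)  # 2. gather from the first column: shape (6,)
x = picked.reshape(p.shape)           # 3. restore p's shape: (3, 2)

assert np.array_equal(x, mu[p, 0])    # same result as advanced indexing
```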

Minimal example:

import tensorflow as tf  # TensorFlow 1.x graph-mode API
tf.InteractiveSession()  # installs itself as the default session, so .eval() works below

mu = tf.reshape(tf.multiply(tf.cast(tf.range(10), tf.float32), 0.1), (10, 1))
mu.eval()
#array([[ 0.        ],
#       [ 0.1       ],
#       [ 0.2       ],
#       [ 0.30000001],
#       [ 0.40000001],
#       [ 0.5       ],
#       [ 0.60000002],
#       [ 0.69999999],
#       [ 0.80000001],
#       [ 0.90000004]], dtype=float32)

p = tf.constant([[1,3],[2,4],[3,1]], dtype=tf.int64)

tf.reshape(tf.gather(mu[:,0], tf.reshape(p, (-1,))), p.shape).eval()

#array([[ 0.1       ,  0.30000001],
#       [ 0.2       ,  0.40000001],
#       [ 0.30000001,  0.1       ]], dtype=float32)
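As far as I know, tf.gather also accepts indices of any rank: with the default axis=0 the result has shape indices.shape + params.shape[1:], so tf.gather(mu[:,0], p) should already come out with p's shape, without the flatten/reshape round trip. NumPy's np.take follows the same shape rule, which makes this easy to check on the example values (NumPy stands in for TensorFlow here):

```python
import numpy as np

mu = (np.arange(10, dtype=np.float32) * 0.1).reshape(10, 1)
p = np.array([[1, 3], [2, 4], [3, 1]], dtype=np.int64)

# 1-D params, 2-D indices: the result is shaped like the indices.
x = np.take(mu[:, 0], p)
print(x.shape)  # (3, 2)

# 2-D params, 2-D indices: indices.shape + params.shape[1:] = (3, 2, 1).
y = np.take(mu, p, axis=0)
print(y.shape)  # (3, 2, 1)

assert np.array_equal(x, y[..., 0])
```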

Two more options, using gather_nd without any reshaping:

tf.gather_nd(mu[:,0], tf.expand_dims(p, axis=-1)).eval()

#array([[ 0.1       ,  0.30000001],
#       [ 0.2       ,  0.40000001],
#       [ 0.30000001,  0.1       ]], dtype=float32)

tf.gather_nd(mu, tf.stack((p, tf.zeros_like(p)), axis=-1)).eval()

#array([[ 0.1       ,  0.30000001],
#       [ 0.2       ,  0.40000001],
#       [ 0.30000001,  0.1       ]], dtype=float32)
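To see that both index tensors address the same elements, here is a small NumPy reference for the gather_nd lookup. gather_nd_ref is a hypothetical helper, not part of TensorFlow or NumPy, and it ignores batch_dims; it treats the last axis of indices as a coordinate tuple:

```python
import numpy as np

def gather_nd_ref(params, indices):
    """Hypothetical NumPy reference for a gather_nd-style lookup (no batch_dims).

    indices[..., k] is the k-th coordinate into params; axes of params not
    covered by the coordinates are returned whole.
    """
    indices = np.asarray(indices)
    # Split the last axis into one index array per coordinate and use them
    # together as a tuple of advanced indices.
    coords = tuple(np.moveaxis(indices, -1, 0))
    return params[coords]

mu = (np.arange(10, dtype=np.float32) * 0.1).reshape(10, 1)
p = np.array([[1, 3], [2, 4], [3, 1]], dtype=np.int64)

# Option 1: 1-D params, one coordinate per entry, indices of shape (3, 2, 1).
a = gather_nd_ref(mu[:, 0], np.expand_dims(p, axis=-1))
# Option 2: 2-D params, (row, 0) coordinates, indices of shape (3, 2, 2).
b = gather_nd_ref(mu, np.stack((p, np.zeros_like(p)), axis=-1))

assert a.shape == b.shape == p.shape
assert np.array_equal(a, b)
```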

You can use tf.map_fn:

x = tf.map_fn(lambda u: tf.gather(tf.squeeze(mu), u), p, dtype=mu.dtype)

map_fn acts as a loop over the first dimension of p: for each slice p[i] (one row of indices), it applies tf.gather, and the per-row results are stacked back into x.
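Conceptually, the map_fn call behaves like the loop below: one gather per row of p, with the results stacked together. This is a plain NumPy sketch of those semantics (np.take standing in for tf.gather), not how TensorFlow executes map_fn internally. Since a single gather should already accept the whole 2-D p, the per-row loop is not needed for correctness:

```python
import numpy as np

mu = (np.arange(10, dtype=np.float32) * 0.1).reshape(10, 1)
p = np.array([[1, 3], [2, 4], [3, 1]], dtype=np.int64)

mu_flat = np.squeeze(mu)                 # shape (10,), like tf.squeeze(mu)
rows = [np.take(mu_flat, u) for u in p]  # one gather per row u = p[i], each shape (2,)
x = np.stack(rows)                       # stack along a new first axis: (3, 2)

assert np.array_equal(x, np.take(mu_flat, p))  # same as one gather over all of p
```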


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM