
Tensorflow multi-dimensional indexing by another tensor

Say I have a rank-k tensor X of shape [n1, n2, ..., nk] and a rank-(k-1) tensor IDX of shape [n2, n3, ..., nk], i.e. IDX has the same shape as the last k-1 dimensions of X. All entries of IDX are integers in [0, n1). I would like to fetch values from X where the position along the first dimension is given by IDX, while the remaining dimensions are iterated over in full.

Example:

import tensorflow as tf

X = tf.constant([[[1,2], [3,4], [5,6]],
                 [[7,8], [9,10], [11,12]]]) # 2 x 3 x 2 tensor
IDX = tf.constant([[1,0], [1,1], [0,1]]) #     3 x 2 tensor
...
# would like to get [[7,2],[9,10],[5,12]]

How to achieve this in Tensorflow efficiently? Thanks!
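The requested operation is out[i2, ..., ik] = X[IDX[i2, ..., ik], i2, ..., ik]. Below is a minimal NumPy sketch of those semantics, useful for checking results (it is not itself a TensorFlow answer). It builds, for every output position, the full coordinate [IDX[...], i2, ..., ik]; stacked on a last axis, that is the layout tf.gather_nd expects for its indices argument when indices.shape[-1] equals the rank of X. gather_first_axis is a hypothetical helper name, and the TensorFlow port (for example building the grids with tf.meshgrid(..., indexing='ij') and combining them with tf.stack) is not shown or tested here.

```python
import numpy as np


def gather_first_axis(x, idx):
    """Hypothetical helper: out[i2, ..., ik] = x[idx[i2, ..., ik], i2, ..., ik]."""
    idx = np.asarray(idx)
    # np.indices gives one coordinate grid per trailing dimension of x,
    # each grid having the same shape as idx.
    grids = np.indices(idx.shape)
    # Full coordinates, shape idx.shape + (x.ndim,): the layout tf.gather_nd uses.
    full_idx = np.stack([idx, *grids], axis=-1)
    # NumPy stand-in for gather_nd: turn the last axis into a tuple of index arrays.
    return np.asarray(x)[tuple(np.moveaxis(full_idx, -1, 0))]


# The question's example with the 2 x 3 x 2 nesting spelled out.
X = np.array([[[1, 2], [3, 4], [5, 6]],
              [[7, 8], [9, 10], [11, 12]]])
IDX = np.array([[1, 0], [1, 1], [0, 1]])

print(gather_first_axis(X, IDX).tolist())  # [[7, 2], [9, 10], [5, 12]]
```

Materializing the full coordinate array costs k integers per output element; that is the price of the gather_nd formulation, which in exchange is a single gather op.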

Did you see the note in the np.choose docs?

Notes

To reduce the chance of misinterpretation, even though the following "abuse" is nominally supported, choices should neither be, nor be thought of as, a single array, i.e., the outermost sequence-like container should be either a list or a tuple.

That is, they want you to pass the choices as a list, like this:

In [432]: list(X)
Out[432]: [array([1, 2]), array([3, 4]), array([5, 6])]
In [433]: np.choose(IDX,list(X))
Out[433]: array([3, 6])
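Applied to the question's 2 x 3 x 2 example, the same pattern splits X along its first axis into a list of 3 x 2 arrays and lets np.choose pick from them elementwise, with IDX acting as the selector. This is a NumPy sketch only; it does not run inside a TensorFlow graph.

```python
import numpy as np

X = np.array([[[1, 2], [3, 4], [5, 6]],
              [[7, 8], [9, 10], [11, 12]]])
IDX = np.array([[1, 0], [1, 1], [0, 1]])

# list(X) gives the choices as [X[0], X[1]], each of shape (3, 2),
# following the docs' advice to pass a list rather than a single array.
result = np.choose(IDX, list(X))
print(result.tolist())  # [[7, 2], [9, 10], [5, 12]]
```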

The indexing equivalent is:

In [436]: X[IDX,np.arange(2)]
Out[436]: array([3, 6])
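The indexing form generalizes to the question's rank-k case by pairing IDX with one broadcastable arange per trailing axis. A sketch, assuming X and IDX are NumPy arrays; np.indices(..., sparse=True) (NumPy 1.17 and later) supplies those aranges, pick_first_axis is a hypothetical name, and the 2-D IDX value [1, 2] is simply the one that reproduces the output above.

```python
import numpy as np


def pick_first_axis(x, idx):
    """Hypothetical helper: x[idx, i2, ..., ik] for every (i2, ..., ik)."""
    # With sparse=True each grid has length n_d along its own axis and 1 elsewhere,
    # so the grids broadcast against idx without materializing full copies.
    grids = np.indices(idx.shape, sparse=True)
    return x[(idx,) + tuple(grids)]


# 2-D case: same as X[IDX, np.arange(2)].
X2 = np.array([[1, 2], [3, 4], [5, 6]])
print(pick_first_axis(X2, np.array([1, 2])).tolist())  # [3, 6]

# The question's 3-D case: same as X[IDX, np.arange(3)[:, None], np.arange(2)].
X3 = np.array([[[1, 2], [3, 4], [5, 6]],
               [[7, 8], [9, 10], [11, 12]]])
IDX3 = np.array([[1, 0], [1, 1], [0, 1]])
print(pick_first_axis(X3, IDX3).tolist())  # [[7, 2], [9, 10], [5, 12]]
```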

choose also has a mode parameter ('raise', 'wrap' or 'clip') that controls how out-of-range indices are handled.
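A short sketch of what mode changes, using the same three 2-element choices as list(X) above. mode='raise' (the default) rejects out-of-range entries, 'clip' clamps them to [0, n-1], and 'wrap' takes them modulo n; with plain indexing, the same effect can be had by applying np.clip or % to the index array first.

```python
import numpy as np

choices = [np.array([1, 2]), np.array([3, 4]), np.array([5, 6])]
X = np.array(choices)          # same data as a (3, 2) array for indexing
a = np.array([3, -1])          # both entries are out of range for 3 choices

clipped = np.choose(a, choices, mode='clip')   # indices become [2, 0]
wrapped = np.choose(a, choices, mode='wrap')   # indices become [0, 2]
print(clipped.tolist(), wrapped.tolist())      # [5, 2] [1, 6]

# Indexing equivalents: fix up the index array, then index as before.
cols = np.arange(2)
assert np.array_equal(X[np.clip(a, 0, 2), cols], clipped)
assert np.array_equal(X[a % 3, cols], wrapped)

try:
    np.choose(a, choices)      # default mode='raise'
except ValueError as err:
    print("raise:", err)
```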

The docs also say it's equivalent to the following (ignoring the mode handling), where ndi = numpy.lib.index_tricks:

np.choose(a,c) == np.array([c[a[I]][I] for I in ndi.ndindex(a.shape)])
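That equivalence can be checked directly. The sketch below uses np.ndindex (the same ndindex the docs reach through numpy.lib.index_tricks) and reshapes the loop's flat result back to a's shape, since the list comprehension yields one element per index. It assumes a and every choice already share a shape, so no broadcasting is involved.

```python
import numpy as np

a = np.array([[1, 0], [1, 1], [0, 1]])
c = list(np.array([[[1, 2], [3, 4], [5, 6]],
                   [[7, 8], [9, 10], [11, 12]]]))

# The docs' formula, with np.ndindex in place of ndi.ndindex.
loop = np.array([c[a[I]][I] for I in np.ndindex(a.shape)]).reshape(a.shape)

assert np.array_equal(loop, np.choose(a, c))
print(loop.tolist())  # [[7, 2], [9, 10], [5, 12]]
```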

Another nuance with choose: it can't handle more than 32 choices (at least in the NumPy version used here), while plain indexing has no such limit.

In [440]: np.choose(IDX,np.ones((33,2)))
...
ValueError: Need at least 1 and at most 32 array objects.

In [442]: np.ones((33,2))[IDX,np.arange(2)]
Out[442]: array([ 1.,  1.])
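When there are more choices than np.choose accepts, the indexing form, or np.take_along_axis (NumPy 1.15 and later), does the same job with no limit on the size of the first dimension. A sketch with 40 choices; the data is arbitrary and only meant to exceed 32.

```python
import numpy as np

n_choices = 40
choices = np.arange(n_choices * 2).reshape(n_choices, 2)  # row i is [2*i, 2*i + 1]
idx = np.array([33, 39])                                  # both beyond the 32-choice limit

# Fancy indexing, as in In [442] above.
by_index = choices[idx, np.arange(2)]

# take_along_axis: idx gets a leading length-1 axis so it lines up with axis 0.
by_take = np.take_along_axis(choices, idx[np.newaxis, :], axis=0)[0]

assert np.array_equal(by_index, by_take)
print(by_index.tolist())  # [66, 79]
```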

You can wrap np.choose() in a Python function and embed it in your TensorFlow graph with tf.py_func(). But you would also need to define a gradient for that function if you want automatic differentiation of your graph to be available for training. I suppose defining the gradient of np.choose() could be a very tricky task, if it is solvable at all.
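For the choices argument, at least, the gradient has a simple form: each output element copies exactly one input element, so the upstream gradient is routed back to position [IDX[i2, ..., ik], i2, ..., ik] and every other position gets zero; the integer IDX itself gets no gradient. A NumPy sketch of that vector-Jacobian product follows, assuming X has shape (n_choices,) + IDX.shape (no broadcasting); choose_vjp is a hypothetical name. Wiring it into TensorFlow, e.g. through tf.custom_gradient around tf.py_function (the TF 2.x counterpart of tf.py_func), is not shown and would need its own testing.

```python
import numpy as np


def choose_vjp(idx, upstream, n_choices):
    """Hypothetical helper: gradient of sum(upstream * np.choose(idx, X)) w.r.t. X.

    X is assumed to have shape (n_choices,) + idx.shape, i.e. no broadcasting.
    """
    # one_hot[i, ...] is True where idx == i; shape (n_choices,) + idx.shape.
    levels = np.arange(n_choices).reshape((n_choices,) + (1,) * idx.ndim)
    one_hot = levels == idx
    # Route each upstream value to the choice it came from; zeros elsewhere.
    return np.where(one_hot, upstream, 0)


X = np.array([[[1, 2], [3, 4], [5, 6]],
              [[7, 8], [9, 10], [11, 12]]])
IDX = np.array([[1, 0], [1, 1], [0, 1]])
upstream = np.array([[1, 2], [3, 4], [5, 6]])

grad = choose_vjp(IDX, upstream, n_choices=2)

# choose is linear in X, so a finite difference in direction dX is exact for integers.
dX = np.arange(X.size).reshape(X.shape)
lhs = np.sum(upstream * (np.choose(IDX, list(X + dX)) - np.choose(IDX, list(X))))
rhs = np.sum(grad * dX)
assert lhs == rhs
```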

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.
