简体   繁体   English

TensorFlow-类似于numpy的张量索引

[英]TensorFlow - numpy-like tensor indexing

In numpy, we can do this: 在numpy中,我们可以这样做:

x = np.random.random((10,10))
a = np.random.randint(0,10,5)
b = np.random.randint(0,10,5)
x[a,b] # gives 5 entries from x, indexed according to the corresponding entries in a and b

When I try something equivalent in TensorFlow: 当我在TensorFlow中尝试等效的东西时:

xt = tf.constant(x)
at = tf.constant(a)
bt = tf.constant(b)
xt[at,bt]

The last line gives a "Bad slice index tensor" exception. 最后一行给出了“错误的切片索引张量”异常。 It seems TensorFlow doesn't support indexing like numpy or Theano. TensorFlow似乎不支持numpy或Theano等索引。

Does anybody know if there is a TensorFlow way of doing this (indexing a tensor by arbitrary values). 是否有人知道是否有TensorFlow方式(通过任意值索引张量)。 I've seen the tf.nn.embedding part, but I'm not sure they can be used for this and even if they can, it's a huge workaround for something this straightforward. 我看过tf.nn.embedding部分,但是我不确定它们是否可以用于此目的,即使可以使用,对于这种简单的操作来说,这也是一个巨大的解决方法。

(Right now, I'm feeding the data from x as an input and doing the indexing in numpy but I hoped to put x inside TensorFlow to get higher efficiency) (现在,我正在将x的数据作为输入并在numpy中进行索引,但我希望将x放入TensorFlow中以获得更高的效率)

You can actually do that now with tf.gather_nd . 实际上,您现在可以使用tf.gather_nd Let's say you have a matrix m like the following: 假设您有一个矩阵m ,如下所示:

| 1 2 3 4 |
| 5 6 7 8 |

And you want to build a matrix r of size, let's say, 3x2, built from elements of m , like this: 您想要构建一个由m元素构建的大小为3x2的矩阵r ,如下所示:

| 3 6 |
| 2 7 |
| 5 3 |
| 1 1 |

Each element of r corresponds to a row and column of m , and you can have matrices rows and cols with these indices (zero-based, since we are programming, not doing math!): 的每个元素r对应的行和列m ,你可以有矩阵rowscols与这些指数(从零开始,因为我们是编程,而不是做数学!):

       | 0 1 |         | 2 1 |
rows = | 0 1 |  cols = | 1 2 |
       | 1 0 |         | 0 2 |
       | 0 0 |         | 0 0 |

Which you can stack into a 3-dimensional tensor like this: 您可以将其堆叠成3维张量,如下所示:

| | 0 2 | | 1 1 | |
| | 0 1 | | 1 2 | |
| | 1 0 | | 2 0 | |
| | 0 0 | | 0 0 | |

This way, you can get from m to r through rows and cols as follows: 通过这种方式,你可以从mr通过rowscols如下:

import numpy as np
import tensorflow as tf

m = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
rows = np.array([[0, 1], [0, 1], [1, 0], [0, 0]])
cols = np.array([[2, 1], [1, 2], [0, 2], [0, 0]])

x = tf.placeholder('float32', (None, None))
idx1 = tf.placeholder('int32', (None, None))
idx2 = tf.placeholder('int32', (None, None))
result = tf.gather_nd(x, tf.stack((idx1, idx2), -1))

with tf.Session() as sess:
    r = sess.run(result, feed_dict={
        x: m,
        idx1: rows,
        idx2: cols,
    })
print(r)

Output: 输出:

[[ 3.  6.]
 [ 2.  7.]
 [ 5.  3.]
 [ 1.  1.]]

LDGN's comment is correct. LDGN的评论是正确的。 This is not possible at the moment, and is a requested feature. 目前尚不可能,这是一项必需的功能。 If you follow issue#206 on github you'll get updated if/when this is available. 如果您在github上关注issue#206,则在有可用更新时会得到更新。 Many people would like this feature. 许多人都希望使用此功能。

For Tensorflow 0.11 , basic indexing has been implemented. 对于Tensorflow 0.11 ,已实现基本索引。 More advanced indexing (like boolean indexing) is still missing but apparently is planned for future versions. 仍然缺少更高级的索引编制(如布尔索引编制),但显然计划在将来的版本中使用。

Advanced indexing can be tracked with https://github.com/tensorflow/tensorflow/issues/4638 可以使用https://github.com/tensorflow/tensorflow/issues/4638跟踪高级索引

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM