繁体   English   中英

如何在tensorflow中使用索引数组?

[英]How can I use the index array in tensorflow?

如果给定一个带有形状(5,3)的矩阵a和带有形状(5,)索引数组b ,我们可以很容易地得到相应的向量c

c = a[np.arange(5), b]

但是,我不能用tensorflow做同样的事情,

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, [5,])
# this line throws error
c = a[tf.range(5), b]

回溯(最近调用最后一次):文件“”,第1行,文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / ops / array_ops.py”,第513行,在_SliceHelper name =名称)

文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / ops / array_ops.py”,第671行,strided_slice中的shrink_axis_mask = shrink_axis_mask)文件“〜/ anaconda2 / lib / python2.7 / site- packages / tensorflow / python / ops / gen_array_ops.py“,第3688行,strided_slice shrink_axis_mask = shrink_axis_mask,name = name)文件”〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / op_def_library。 py“,第763行,在apply_op中op_def = op_def)文件”〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / ops.py“,第2397行,在create_op中set_shapes_for_outputs(ret)文件” 〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / ops.py“,第1757行,在set_shapes_for_outputs中shape = shape_func(op)文件”〜/ anaconda2 / lib / python2.7 / site- packages / tensorflow / python / framework / ops.py“,第1707行,在call_with_requiring中返回call_cpp_shape_fn(op,require_shape_fn = True)文件”〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / common_shapes .py“,第610行,在call_cpp_中 shape_fn debug_python_shape_fn,require_shape_fn)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / common_shapes.py”,第675行,在_call_cpp_shape_fn_impl中引发ValueError(err.message)ValueError:Shape必须是rank 1,但是'strided_slice_14'(op:'StridedSlice')的排名为2,输入形状为:[5,3],[2,5],[2,5],[2]。

我的问题是,如果我不能使用上述方法在numpy中产生tensorflow的预期结果,我该怎么办?

此功能目前尚未在TensorFlow中实现。 GitHub 问题#4638正在跟踪NumPy风格的“高级”索引的实现。 但是,您可以使用tf.gather_nd()运算符来实现您的程序:

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, (5,))

row_indices = tf.range(5)

# `indices` is a 5 x 2 matrix of coordinates into `a`.
indices = tf.transpose([row_indices, b])

c = tf.gather_nd(a, indices)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM