繁体   English   中英

如何从 Tensorflow 中的 3D 数组中提取行和列

[英]How to extract rows and columns from a 3D array in Tensorflow

我想对 TensorFlow 张量执行以下索引操作。 TensorFlow 中将bc作为输出的等效操作应该是什么? 虽然tf.gather_nd文档有几个例子,但我无法生成等效的indices张量来获得这些结果。

import tensorflow as tf
import numpy as np

a=np.arange(18).reshape((2,3,3))

idx=[2,0,1] #it can be any validing re-ordering index list

#These are the two numpy operations that I want to do in Tensorflow
b=a[:,idx,:]
c=a[:,:,idx] 

# TensorFlow operations

aT=tf.constant(a)
idxT=tf.constant(idx)

# what should be these two indices  
idx1T=tf.reshape(idxT, (3,1)) 
idx2T=tf.reshape(idxT, (1,1,3))

bT=tf.gather_nd(aT, idx1T ) #does not work
cT=tf.gather_nd(aT, idx2T)  #does not work

with tf.Session() as sess:
    b1,c1=sess.run([bT,cT])

print(np.allclose(b,b1))
print(np.allclose(c,c1))

我不限于tf.gather_nd在 GPU 上实现相同操作的任何其他建议都会有所帮助。

编辑:我已经更新了拼写错误的问题:

旧语句: c=a[:,idx]

新语句: c=a[:,:,idx]我想要实现的也是重新排序列。

这可以通过tf.gather完成,使用axis参数:

import tensorflow as tf
import numpy as np

a = np.arange(18).reshape((2,3,3))
idx = [2,0,1]
b = a[:, idx, :]
c = a[:, :, idx]

aT = tf.constant(a)
idxT = tf.constant(idx)
bT = tf.gather(aT, idxT, axis=1)
cT = tf.gather(aT, idxT, axis=2)

with tf.Session() as sess:
    b1, c1=sess.run([bT, cT])

print(np.allclose(b, b1))
print(np.allclose(c, c1))

输出:

True
True

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM