[英]How to extract rows and columns from a 3D array in Tensorflow
我想对 TensorFlow 张量执行以下索引操作。 TensorFlow 中将b
和c
作为输出的等效操作应该是什么? 虽然tf.gather_nd
文档有几个例子,但我无法生成等效的indices
张量来获得这些结果。
import tensorflow as tf
import numpy as np
a=np.arange(18).reshape((2,3,3))
idx=[2,0,1] #it can be any validing re-ordering index list
#These are the two numpy operations that I want to do in Tensorflow
b=a[:,idx,:]
c=a[:,:,idx]
# TensorFlow operations
aT=tf.constant(a)
idxT=tf.constant(idx)
# what should be these two indices
idx1T=tf.reshape(idxT, (3,1))
idx2T=tf.reshape(idxT, (1,1,3))
bT=tf.gather_nd(aT, idx1T ) #does not work
cT=tf.gather_nd(aT, idx2T) #does not work
with tf.Session() as sess:
b1,c1=sess.run([bT,cT])
print(np.allclose(b,b1))
print(np.allclose(c,c1))
我不限于tf.gather_nd
在 GPU 上实现相同操作的任何其他建议都会有所帮助。
旧语句: c=a[:,idx]
,
新语句: c=a[:,:,idx]
我想要实现的也是重新排序列。
这可以通过tf.gather
完成,使用axis
参数:
import tensorflow as tf
import numpy as np
a = np.arange(18).reshape((2,3,3))
idx = [2,0,1]
b = a[:, idx, :]
c = a[:, :, idx]
aT = tf.constant(a)
idxT = tf.constant(idx)
bT = tf.gather(aT, idxT, axis=1)
cT = tf.gather(aT, idxT, axis=2)
with tf.Session() as sess:
b1, c1=sess.run([bT, cT])
print(np.allclose(b, b1))
print(np.allclose(c, c1))
输出:
True
True
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.