
In Keras with TensorFlow, how can I reorder (reindex) one axis of an n-dimensional tensor?

Like this, except without the errors:

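import numpy as np
import tensorflow as tf

input = tf.convert_to_tensor(np.random.rand(500, 100, 5))
new_order = [0, 4, 1, 3, 2]
output = input[:, :, new_order]  # fails: tf.Tensor does not support NumPy-style list indexing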

The closest thing I found is tf.gather, but I haven't been able to make it work.
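One plausible reason a direct tf.gather call seems not to work is that it selects along axis 0 by default, so it reorders the wrong axis unless the axis argument is given. The small sketch below (a made-up 2x3 tensor, assuming TensorFlow 2.x) shows the difference between the default and an explicit axis.

```python
import tensorflow as tf

# Small 2x3 example so the effect of each call is easy to see.
x = tf.constant([[10, 11, 12],
                 [20, 21, 22]])

# Default: tf.gather indexes axis 0, so this reorders (here: swaps) the rows.
rows = tf.gather(x, [1, 0])              # -> [[20, 21, 22], [10, 11, 12]]

# With axis=1 the columns are reordered instead.
cols = tf.gather(x, [2, 0, 1], axis=1)   # -> [[12, 10, 11], [22, 20, 21]]

print(rows.numpy())
print(cols.numpy())
```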

I find it easiest to transpose the tensor so that the axis being reindexed becomes the first dimension, gather along it, and then transpose back to the original shape.

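output = tf.transpose(
    tf.gather(
        tf.transpose(input, [2, 0, 1]),  # (500, 100, 5) -> (5, 500, 100): move axis 2 to the front
        new_order                        # reorder along the (now first) axis
    ),
    [1, 2, 0]                            # (5, 500, 100) -> (500, 100, 5): restore the original layout
)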
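A quick way to check the transpose-gather-transpose approach is to compare it with NumPy's list indexing, which does support arr[:, :, new_order]. The sketch below assumes TensorFlow 2.x, where tf.gather also accepts an axis argument that makes the transposes unnecessary; the tensor is named x instead of input to avoid shadowing Python's built-in.

```python
import numpy as np
import tensorflow as tf

arr = np.random.rand(500, 100, 5)
new_order = [0, 4, 1, 3, 2]
x = tf.convert_to_tensor(arr)

# Approach from the answer: move axis 2 to the front, gather, move it back.
via_transpose = tf.transpose(tf.gather(tf.transpose(x, [2, 0, 1]), new_order), [1, 2, 0])

# Same reordering in one call, using tf.gather's axis argument.
via_axis = tf.gather(x, new_order, axis=2)

# Reference result: NumPy supports list ("fancy") indexing directly.
expected = arr[:, :, new_order]

print(via_transpose.shape)
print(np.array_equal(via_transpose.numpy(), expected))
print(np.array_equal(via_axis.numpy(), expected))
```

Since gathering only copies values, exact equality (np.array_equal) is a fair check here rather than a tolerance-based comparison.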
