
How can a tensor in TensorFlow be sliced using elements of another array as an index?

I'm looking for a function similar to tf.unsorted_segment_sum, but instead of summing the segments, I want to get each segment as its own tensor.

For example, I have the code below. (My real tensor has shape (10000, 63), and there would be 2500 segments.)

    import tensorflow as tf

    to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                                [0.3, 0.2, 0.2, 0.6, 0.3],
                                [0.9, 0.8, 0.7, 0.6, 0.5],
                                [2.0, 2.0, 2.0, 2.0, 2.0]])

    indices = tf.constant([0, 2, 0, 1])  # segment id of each row
    num_segments = 3
    # TF 1.x name; TF 2.x uses tf.math.unsorted_segment_sum
    tf.unsorted_segment_sum(to_be_sliced, indices, num_segments)

The output here would be (rows counted from 1):

    array([row1 + row3, row4, row2])
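To make that expected result concrete, here is a minimal sketch of the same call, assuming TensorFlow 2.x eager execution (where the function lives at tf.math.unsorted_segment_sum). Output row i is the sum of all input rows whose segment id is i; for this data, row1 + row3 happens to be a row of ones.

```python
import tensorflow as tf

to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])
indices = tf.constant([0, 2, 0, 1])  # segment id of each row
num_segments = 3

# Rows that share a segment id are added together; output row i is segment i.
summed = tf.math.unsorted_segment_sum(to_be_sliced, indices, num_segments)
print(summed.numpy())
```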

What I am looking for is 3 tensors with different shapes (perhaps a list of tensors): the first containing the first and third rows of the original (shape (2, 5)), the second containing the 4th row (shape (1, 5)), and the third containing the second row, like this:

    [array([[0.1, 0.2, 0.3, 0.4, 0.5],
            [0.9, 0.8, 0.7, 0.6, 0.5]]),
     array([[2.0, 2.0, 2.0, 2.0, 2.0]]),
     array([[0.3, 0.2, 0.2, 0.6, 0.3]])]

Thanks in advance!

In your case, you can use NumPy-style slicing in TensorFlow. Basic slicing only selects contiguous ranges of rows, so it fits when each segment's rows sit next to each other:

    sliced_1 = to_be_sliced[:3, :]
    # [[0.1 0.2 0.3 0.4 0.5]
    #  [0.3 0.2 0.2 0.6 0.3]
    #  [0.9 0.8 0.7 0.6 0.5]]
    sliced_2 = to_be_sliced[3, :]
    # [2. 2. 2. 2. 2.]
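A minimal sketch of the shape details involved, assuming TensorFlow 2.x eager execution: an integer row index drops the row axis, a length-1 slice keeps it (matching the (1, 5) shape the question asks for), and non-contiguous rows such as rows 0 and 2 (segment 0 in the question) can be picked by index with tf.gather along axis 0.

```python
import tensorflow as tf

to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])

row_3 = to_be_sliced[3, :]     # integer index drops the row axis: shape (5,)
rows_3 = to_be_sliced[3:4, :]  # a length-1 slice keeps it: shape (1, 5)

# Rows 0 and 2 are not contiguous, so select them by index
# with tf.gather along axis 0: shape (2, 5).
segment_0 = tf.gather(to_be_sliced, [0, 2], axis=0)

print(row_3.shape, rows_3.shape, segment_0.shape)
```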

Or, as a more general option, you can do it like this:

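    import tensorflow as tf

    to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                                [0.3, 0.2, 0.2, 0.6, 0.3],
                                [0.9, 0.8, 0.7, 0.6, 0.5],
                                [2.0, 2.0, 2.0, 2.0, 2.0]])

    # Each [row] index selects a whole row; the row numbers are written
    # out by hand from indices = [0, 2, 0, 1].
    first_tensor = tf.gather_nd(to_be_sliced, [[0], [2]])  # shape (2, 5)
    second_tensor = tf.gather_nd(to_be_sliced, [[3]])      # shape (1, 5)
    third_tensor = tf.gather_nd(to_be_sliced, [[1]])       # shape (1, 5)

    # Optional: rejoin the segments in segment order as one (4, 5) tensor.
    concat = tf.concat([first_tensor, second_tensor, third_tensor], axis=0)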
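The row numbers in the gather_nd approach are hard-coded. A minimal sketch that derives them from indices instead, assuming TensorFlow 2.x eager execution: tf.where on a 1-D boolean tensor returns an (n, 1) int64 tensor of matching positions, which is exactly the index shape tf.gather_nd needs to pull out whole rows.

```python
import tensorflow as tf

to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])
indices = tf.constant([0, 2, 0, 1])
num_segments = 3

segments = []
for i in range(num_segments):
    # Positions of the rows in segment i, as an (n, 1) int64 tensor.
    row_coords = tf.where(tf.equal(indices, i))
    # Each (row,) coordinate selects a whole row, giving shape (n, 5).
    segments.append(tf.gather_nd(to_be_sliced, row_coords))

for seg in segments:
    print(seg.numpy())
```

This still builds one gather per segment, so its cost grows with num_segments.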

You can do that like this:

    import tensorflow as tf

    to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                                [0.3, 0.2, 0.2, 0.6, 0.3],
                                [0.9, 0.8, 0.7, 0.6, 0.5],
                                [2.0, 2.0, 2.0, 2.0, 2.0]])
    indices = tf.constant([0, 2, 0, 1])
    num_segments = 3
    # For each segment id i, keep only the rows whose entry in indices equals i.
    result = [tf.boolean_mask(to_be_sliced, tf.equal(indices, i)) for i in range(num_segments)]
    # TF 1.x API; TF 2.x runs eagerly, so call r.numpy() instead of using a Session.
    with tf.Session() as sess:
        print(*sess.run(result), sep='\n')

Output:

    [[0.1 0.2 0.3 0.4 0.5]
     [0.9 0.8 0.7 0.6 0.5]]
    [[2. 2. 2. 2. 2.]]
    [[0.3 0.2 0.2 0.6 0.3]]
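TensorFlow also has a dedicated op for this, tf.dynamic_partition(data, partitions, num_partitions), which returns a Python list of num_partitions tensors where output i holds the rows whose partition id is i, in their original order. The sketch below assumes TensorFlow 2.x eager execution and checks it against the boolean_mask list comprehension above. Because it is a single op rather than one mask per segment, it may be preferable with 2500 segments.

```python
import tensorflow as tf

to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])
indices = tf.constant([0, 2, 0, 1])  # int32 segment id for each row
num_segments = 3

# A single op returning a Python list of num_segments tensors;
# parts[i] holds the rows whose id is i, in their original order.
parts = tf.dynamic_partition(to_be_sliced, indices, num_segments)

# The boolean_mask list comprehension, run eagerly (no Session).
masked = [tf.boolean_mask(to_be_sliced, tf.equal(indices, i))
          for i in range(num_segments)]

for part, ref in zip(parts, masked):
    # Both approaches should select exactly the same rows.
    assert part.numpy().tolist() == ref.numpy().tolist()
    print(part.numpy())
```

With the real data, indices would be the length-10000 vector of segment ids and num_segments = 2500; dynamic_partition requires every id to lie in [0, num_segments).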


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM