I'm looking for a function similar to tf.unsorted_segment_sum, but instead of summing the segments, I want to get each segment as a separate tensor.
For example, I have this code (in reality, my tensor has shape (10000, 63) and there are 2500 segments):
import tensorflow as tf
to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])
indices = tf.constant([0, 2, 0, 1])
num_segments = 3
tf.unsorted_segment_sum(to_be_sliced, indices, num_segments)
The output here would be:
array([row1 + row3, row4, row2])
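To make the segment-sum semantics concrete, here is a minimal sketch assuming TensorFlow 2.x in eager mode, where the op is exposed as tf.math.unsorted_segment_sum. The NumPy cross-check with np.add.at is an illustration added here, not part of the original question:

```python
import numpy as np
import tensorflow as tf

to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])
indices = tf.constant([0, 2, 0, 1])
num_segments = 3

# Output row i is the element-wise sum of every input row whose id is i.
summed = tf.math.unsorted_segment_sum(to_be_sliced, indices, num_segments)

# NumPy reference: np.add.at is unbuffered, so both rows with id 0
# are accumulated into slot 0 instead of one overwriting the other.
data = to_be_sliced.numpy()
expected = np.zeros((num_segments, data.shape[1]), dtype=np.float32)
np.add.at(expected, indices.numpy(), data)

print(summed.numpy())
print(np.allclose(summed.numpy(), expected))  # True
```

Row 0 of the result is the element-wise sum of input rows 0 and 2, which is exactly the per-row structure the question wants to keep instead of collapse.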
What I'm looking for is 3 tensors with different shapes (maybe a list of tensors): the first containing the first and third rows of the original (shape (2, 5)), the second containing the 4th row (shape (1, 5)), and the third containing the second row, like this:
[array([[0.1, 0.2, 0.3, 0.4, 0.5],
[0.9, 0.8, 0.7, 0.6, 0.5]]),
array([[2.0, 2.0, 2.0, 2.0, 2.0]]),
array([[0.3, 0.2, 0.2, 0.6, 0.3]])]
Thanks in advance!
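TensorFlow also has a built-in op that returns each segment as its own tensor: tf.dynamic_partition(data, partitions, num_partitions). It was not mentioned in the original thread; the sketch below assumes TensorFlow 2.x eager mode and reuses the question's example data:

```python
import tensorflow as tf

to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])
indices = tf.constant([0, 2, 0, 1])  # int32 segment id for each row
num_segments = 3

# Row r is routed to segments[indices[r]]; the result is a Python list
# of num_segments tensors whose first dimensions can differ.
segments = tf.dynamic_partition(to_be_sliced, indices, num_segments)

for i, segment in enumerate(segments):
    print(i, segment.numpy().shape)
```

Within each partition, rows keep their original relative order. For the real data in the question, the call is unchanged apart from num_segments=2500; a segment id that no row uses yields an empty tensor of shape (0, 63).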
For your case, you can use NumPy-style slicing in TensorFlow, so this will work:
sliced_1 = to_be_sliced[:3, :]
# [[0.1 0.2 0.3 0.4 0.5]
#  [0.3 0.2 0.2 0.6 0.3]
#  [0.9 0.8 0.7 0.6 0.5]]
sliced_2 = to_be_sliced[3, :]
# [2. 2. 2. 2. 2.]
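Basic slicing only picks contiguous (or evenly strided) rows, so on its own it cannot express an arbitrary segment such as rows 0 and 2 together with an unrelated row. A minimal sketch assuming TensorFlow 2.x eager mode; tf.gather is swapped in here for the non-contiguous case:

```python
import tensorflow as tf

to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])

# Basic slicing: a contiguous block of rows (0, 1, 2).
first_three = to_be_sliced[:3, :]

# Strided slicing: every second row (0 and 2), which happens to be segment 0.
every_other = to_be_sliced[::2, :]

# Arbitrary rows in any order need an explicit index list.
rows_0_and_2 = tf.gather(to_be_sliced, [0, 2])
row_3 = tf.gather(to_be_sliced, [3])  # keeps the leading axis: shape (1, 5)

print(first_three.numpy().shape, every_other.numpy().shape, row_3.numpy().shape)
```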
As a more general option, you can use tf.gather_nd, which also handles non-contiguous rows:
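import tensorflow as tf
to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])
# Each inner list is one row index; the result has one row per index.
first_tensor = tf.gather_nd(to_be_sliced, [[0], [2]])  # rows 0 and 2, shape (2, 5)
second_tensor = tf.gather_nd(to_be_sliced, [[3]])      # row 3, shape (1, 5)
third_tensor = tf.gather_nd(to_be_sliced, [[1]])       # row 1, shape (1, 5)
# Concatenating merges the segments back into one (4, 5) tensor ordered by
# segment; keep the three tensors separate if you need a list.
concat = tf.concat([first_tensor, second_tensor, third_tensor], axis=0)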
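The gather_nd approach can be generalized so the row indices are derived from the segment ids instead of being written out by hand. A minimal sketch, assuming TensorFlow 2.x eager mode; the loop and the name segments are illustrative:

```python
import tensorflow as tf

to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])
indices = tf.constant([0, 2, 0, 1])
num_segments = 3

segments = []
for i in range(num_segments):
    # tf.where on a boolean vector returns the positions of the True entries
    # as an int64 tensor of shape (n_i, 1), the index layout gather_nd expects.
    row_ids = tf.where(tf.equal(indices, i))
    segments.append(tf.gather_nd(to_be_sliced, row_ids))

# Optional: stitch the segments back together, grouped by segment id.
concat = tf.concat(segments, axis=0)
```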
You can do that like this:
import tensorflow as tf
to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])
indices = tf.constant([0, 2, 0, 1])
num_segments = 3
result = [tf.boolean_mask(to_be_sliced, tf.equal(indices, i)) for i in range(num_segments)]
with tf.Session() as sess:
print(*sess.run(result), sep='\n')
Output:
[[0.1 0.2 0.3 0.4 0.5]
[0.9 0.8 0.7 0.6 0.5]]
[[2. 2. 2. 2. 2.]]
[[0.3 0.2 0.2 0.6 0.3]]
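tf.Session exists only in TensorFlow 1.x (2.x keeps it as tf.compat.v1.Session). Assuming TensorFlow 2.x with eager execution, which is the default, the same list comprehension can be evaluated directly without a session:

```python
import tensorflow as tf

to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
                            [0.3, 0.2, 0.2, 0.6, 0.3],
                            [0.9, 0.8, 0.7, 0.6, 0.5],
                            [2.0, 2.0, 2.0, 2.0, 2.0]])
indices = tf.constant([0, 2, 0, 1])
num_segments = 3

# In eager mode each boolean_mask call returns a concrete tensor immediately.
result = [tf.boolean_mask(to_be_sliced, tf.equal(indices, i))
          for i in range(num_segments)]

for segment in result:
    print(segment.numpy())
```

Each tf.boolean_mask call compares all N segment ids, so the comprehension performs N × num_segments comparisons; for the sizes in the question that is 10000 × 2500 = 25 million.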