
How to gather a tensor with unknown first (batch) dimension?

I have a tensor of shape (?, 3, 2, 5). I want to supply pairs of indices to select from the first and second non-batch dimensions of that tensor, which have sizes 3 and 2.

If I supply 4 such pairs, I would expect the resulting shape to be (?, 4, 5). I thought this is what batch_gather is for: to "broadcast" the gather indices over the first (batch) dimension. But that is not what it does:

import tensorflow as tf
data = tf.placeholder(tf.float32, (None, 3, 2, 5))


indices = tf.constant([
    [2, 1],
    [2, 0],
    [1, 1],
    [0, 1]
], tf.int32)

tf.batch_gather(data, indices)

Which results in <tf.Tensor 'Reshape_3:0' shape=(4, 2, 2, 5) dtype=float32> instead of the shape that I was expecting.

How can I do what I want without explicitly indexing the batches (which have an unknown size)?

With tf.batch_gather, the leading dimensions of the data tensor's shape must match the leading dimensions of the indices tensor's shape.

import tensorflow as tf
data = tf.placeholder(tf.float32, (2, 3, 2, 5))
print(data.shape)  # (2, 3, 2, 5)

# shape of indices, [2, 3]
indices = tf.constant([
    [1, 1, 1],
    [0, 0, 1]
])  
print(tf.batch_gather(data, indices).shape) # (2, 3, 2, 5)
# if the indices' shape were (2, 3, 1), the output shape would be (2, 3, 1, 5)
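To make that batch-wise rule concrete, tf.batch_gather's behavior on the example above can be mimicked with NumPy fancy indexing (a sketch using NumPy instead of TF so it runs self-contained; the indexing rule is output[i, j] = data[i, indices[i, j]]):

```python
import numpy as np

# data shape (2, 3, 2, 5): 2 batch elements, each gathered along its axis 0
data = np.arange(2 * 3 * 2 * 5).reshape(2, 3, 2, 5)
indices = np.array([
    [1, 1, 1],
    [0, 0, 1],
])  # shape (2, 3): one row of gather indices per batch element

# output[i, j] = data[i, indices[i, j]] -- the rule tf.batch_gather applies
out = data[np.arange(2)[:, None], indices]
print(out.shape)  # (2, 3, 2, 5)
```

Note that the result has the same rank as the input: each row of indices just reorders slices within its own batch element, which is why the question's (4, 2) index pairs cannot produce a (?, 4, 5) result this way.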

What you want instead is tf.gather_nd, as follows:

# move the two indexed axes (sizes 3 and 2) to the front: (?, 3, 2, 5) -> (3, 2, ?, 5)
data_transpose = tf.transpose(data, perm=[1, 2, 0, 3])
t_transpose = tf.gather_nd(data_transpose, indices)  # (4, ?, 5)
t = tf.transpose(t_transpose, perm=[1, 0, 2])
print(t.shape)  # (?, 4, 5)
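The transpose-then-gather idea can be checked numerically in plain NumPy (a sketch with a concrete batch size of 4; it uses perm (1, 2, 0, 3) so that the two axes addressed by the question's index pairs, of sizes 3 and 2, come first):

```python
import numpy as np

data = np.arange(4 * 3 * 2 * 5).reshape(4, 3, 2, 5)
idxs = np.array([[2, 1], [2, 0], [1, 1], [0, 1]])

# (4, 3, 2, 5) -> (3, 2, 4, 5): indexed axes first, batch axis third
data_t = data.transpose(1, 2, 0, 3)

# NumPy equivalent of gather_nd with (4, 2) index pairs: shape (4, 4, 5),
# i.e. (num_pairs, batch, channels)
gathered = data_t[idxs[:, 0], idxs[:, 1]]

# swap back to (batch, num_pairs, channels)
result = gathered.transpose(1, 0, 2)
print(result.shape)  # (4, 4, 5)

# matches direct NumPy advanced indexing on the original tensor
print(np.array_equal(result, data[:, idxs[:, 0], idxs[:, 1]]))  # True
```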

I wanted to avoid transpose and Python loops, and I think this works. This was the setup:

import numpy as np
import tensorflow as tf

shape = None, 3, 2, 5
data = tf.placeholder(tf.int32, shape)
idxs_list = [
    [2, 1],
    [2, 0],
    [1, 1],
    [0, 1]
]
idxs = tf.constant(idxs_list, tf.int32)

This allows us to gather the results:

batch_size, num_idxs, num_channels = tf.shape(data)[0], tf.shape(idxs)[0], shape[-1]

batch_idxs = tf.math.floordiv(tf.range(0, batch_size * num_idxs), num_idxs)[:, None]
nd_idxs = tf.concat([batch_idxs, tf.tile(idxs, (batch_size, 1))], axis=1)

gathered = tf.reshape(tf.gather_nd(data, nd_idxs), (batch_size, num_idxs, num_channels))

When we run with a batch size of 4, we get a result with shape (4, 4, 5), which is (batch_size, num_idxs, num_channels).

vals_shape = 4, *shape[1:]
vals = np.arange(int(np.prod(vals_shape))).reshape(vals_shape)

with tf.Session() as sess:
    result = gathered.eval(feed_dict={data: vals})

Which ties out with numpy indexing:

x, y = zip(*idxs_list)
assert np.array_equal(result, vals[:, x, y])

Essentially, gather_nd wants batch indices in the first dimension, and those have to be repeated once for each index pair (i.e., [0, 0, 0, 0, 1, 1, 1, 1, 2, ...] if there are 4 index pairs).

Since there doesn't seem to be a tf.repeat, I used range and floordiv, and then concatenated the batch indices with the desired (x, y) indices (which are themselves tiled batch_size times).
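The same index construction can be rendered in plain NumPy (a sketch assuming a concrete batch size of 4, so no TF session is needed; newer TensorFlow versions do ship a tf.repeat that would replace the floordiv trick):

```python
import numpy as np

batch_size, num_idxs = 4, 4
idxs = np.array([[2, 1], [2, 0], [1, 1], [0, 1]])

# floordiv trick: [0, 1, ..., 15] // 4 -> [0, 0, 0, 0, 1, 1, 1, 1, ...]
batch_idxs = (np.arange(batch_size * num_idxs) // num_idxs)[:, None]

# each row of nd_idxs is a full (batch, x, y) coordinate for gather_nd
nd_idxs = np.concatenate([batch_idxs, np.tile(idxs, (batch_size, 1))], axis=1)
print(nd_idxs.shape)  # (16, 3)

# gathering with those coordinates reproduces vals[:, x, y]
vals = np.arange(4 * 3 * 2 * 5).reshape(4, 3, 2, 5)
gathered = vals[nd_idxs[:, 0], nd_idxs[:, 1], nd_idxs[:, 2]].reshape(4, 4, 5)
print(np.array_equal(gathered, vals[:, idxs[:, 0], idxs[:, 1]]))  # True
```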
