I have a tensor of shape (?, 3, 2, 5). I want to supply pairs of indices to select from the two dimensions after the batch dimension, which have shape (3, 2).
If I supply 4 such pairs, I would expect the resulting shape to be (?, 4, 5). I'd thought this is what batch_gather is for: to "broadcast" gathering indices over the first (batch) dimension. But this is not what it's doing:
import tensorflow as tf
data = tf.placeholder(tf.float32, (None, 3, 2, 5))
indices = tf.constant([
[2, 1],
[2, 0],
[1, 1],
[0, 1]
], tf.int32)
tf.batch_gather(data, indices)
Which results in <tf.Tensor 'Reshape_3:0' shape=(4, 2, 2, 5) dtype=float32>
instead of the shape that I was expecting.
How can I do what I want without explicitly indexing the batches (which have an unknown size)?
With tf.batch_gather, the leading dimensions of the tensor's shape must match the leading dimensions of the indices tensor's shape.
import tensorflow as tf
data = tf.placeholder(tf.float32, (2, 3, 2, 5))
print(data.shape)  # (2, 3, 2, 5)
# shape of indices, [2, 3]
indices = tf.constant([
[1, 1, 1],
[0, 0, 1]
])
print(tf.batch_gather(data, indices).shape) # (2, 3, 2, 5)
# if the shape of indices were (2, 3, 1), the output shape would be (2, 3, 1, 5)
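To make the shape rule concrete, here is a minimal NumPy sketch of what batch_gather does for 2-D indices. The helper name batch_gather_np is hypothetical and is only a stand-in for illustration, not TensorFlow's implementation; it assumes the first dimension of the indices matches the batch dimension of the data.

```python
import numpy as np

def batch_gather_np(params, indices):
    """Illustrative stand-in: out[i, j] = params[i, indices[i, j]]."""
    # One row index per batch entry, shaped (batch, 1) so it broadcasts with indices.
    rows = np.arange(indices.shape[0])[:, None]
    return params[rows, indices]

params = np.arange(2 * 3 * 2 * 5).reshape(2, 3, 2, 5)
indices = np.array([[1, 1, 1],
                    [0, 0, 1]])

out = batch_gather_np(params, indices)
print(out.shape)  # (2, 3, 2, 5)
```

Each row of indices is applied to its own batch entry, which is why a (4, 2) index tensor is read as "4 batch entries, 2 indices each" rather than as 4 coordinate pairs.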
What you want instead is tf.gather_nd, as follows:
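# data and indices are the (None, 3, 2, 5) placeholder and the (4, 2) index pairs from the question
data_transpose = tf.transpose(data, perm=[1, 2, 0, 3])  # (3, 2, ?, 5): the indexed axes come first
t_transpose = tf.gather_nd(data_transpose, indices)  # (4, ?, 5)
t = tf.transpose(t_transpose, perm=[1, 0, 2])
print(t.shape) # (?, 4, 5)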
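The same idea can be checked in plain NumPy, as a rough sketch: move the two indexed axes to the front, pick one slice per index pair, then move the batch axis back. The array names and the batch size of 6 are assumptions chosen for illustration.

```python
import numpy as np

vals = np.arange(6 * 3 * 2 * 5).reshape(6, 3, 2, 5)  # batch of 6
idxs = np.array([[2, 1], [2, 0], [1, 1], [0, 1]])

vals_t = vals.transpose(1, 2, 0, 3)        # (3, 2, batch, 5)
picked = vals_t[idxs[:, 0], idxs[:, 1]]    # (4, batch, 5), one slice per pair
result_t = picked.transpose(1, 0, 2)       # (batch, 4, 5)

# Direct NumPy indexing gives the same answer.
expected = vals[:, idxs[:, 0], idxs[:, 1]]
print(result_t.shape)  # (6, 4, 5)
```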
I wanted to avoid transpose and Python loops, and I think this works. This was the setup:
import numpy as np
import tensorflow as tf
shape = None, 3, 2, 5
data = tf.placeholder(tf.int32, shape)
idxs_list = [
[2, 1],
[2, 0],
[1, 1],
[0, 1]
]
idxs = tf.constant(idxs_list, tf.int32)
This allows us to gather the results:
batch_size, num_idxs, num_channels = tf.shape(data)[0], tf.shape(idxs)[0], shape[-1]
batch_idxs = tf.math.floordiv(tf.range(0, batch_size * num_idxs), num_idxs)[:, None]
nd_idxs = tf.concat([batch_idxs, tf.tile(idxs, (batch_size, 1))], axis=1)
gathered = tf.reshape(tf.gather_nd(data, nd_idxs), (batch_size, num_idxs, num_channels))
When we run with a batch size of 4, we get a result with shape (4, 4, 5), which is (batch_size, num_idxs, num_channels).
vals_shape = 4, *shape[1:]
vals = np.arange(int(np.prod(vals_shape))).reshape(vals_shape)
with tf.Session() as sess:
    result = gathered.eval(feed_dict={data: vals})
Which ties out with numpy indexing:
x, y = zip(*idxs_list)
assert np.array_equal(result, vals[:, x, y])
Essentially, gather_nd wants the batch index as the first entry of every index row, and those batch indices have to be repeated once for each index pair (i.e., [0, 0, 0, 0, 1, 1, 1, 1, 2, ...] if there are 4 index pairs).
Since there doesn't seem to be a tf.repeat, I used range and floordiv, and then concatenated the batch indices with the desired (x, y) indices (which are themselves tiled batch_size times).
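For anyone who wants to see the index construction on its own, here is a small NumPy sketch of the same steps (arange plus floor division, tile, concatenate). The batch size of 3 is an assumption for illustration, and NumPy tuple indexing stands in for gather_nd.

```python
import numpy as np

batch_size = 3
idxs = np.array([[2, 1], [2, 0], [1, 1], [0, 1]])
num_idxs = len(idxs)

# Batch index repeated once per index pair.
batch_idxs = np.arange(batch_size * num_idxs) // num_idxs
print(batch_idxs.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]

# Each row becomes (batch, x, y); the (x, y) pairs are tiled once per batch entry.
nd_idxs = np.concatenate(
    [batch_idxs[:, None], np.tile(idxs, (batch_size, 1))], axis=1)
print(nd_idxs[4].tolist())  # [1, 2, 1]

vals = np.arange(batch_size * 3 * 2 * 5).reshape(batch_size, 3, 2, 5)
flat = vals[tuple(nd_idxs.T)]                    # (batch_size * num_idxs, 5)
gathered = flat.reshape(batch_size, num_idxs, 5)
```

The floor-division trick yields the same sequence as np.repeat(np.arange(batch_size), num_idxs), which is what the range and floordiv combination is standing in for.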