I am trying to slice a tensor with an indices tensor. For this purpose I am trying to use tf.gather. However, I am having a hard time understanding the documentation and can't get it to work as I would expect:
I have two tensors: an activations tensor with a shape of [1,240,4] and an ids tensor with the shape [1,1,120]. I want to slice the second dimension of the activations tensor with the indices provided in the third dimension of the ids tensor:
downsampled_activations = tf.gather(activations, ids, axis=1)
I have given it the axis=1 option since that is the axis in the activations tensor I want to slice.
However, this does not render the expected result and only gives me the following error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,0,1] = 1 is not in [0, 1)
I have tried various combinations of the axis and batch_dims options, but to no avail so far, and the documentation doesn't really help me. An explanation of the parameters in more detail, or in terms of the example above, would be very helpful!
Edit: The IDs are precomputed before runtime and come in through an input pipeline as such:
features = tf.io.parse_single_example(
    serialized_example,
    features={'featureIDs': tf.io.FixedLenFeature([], tf.string)})
They are then reshaped into the previous format:
feature_ids_raw = tf.decode_raw(features['featureIDs'], tf.int32)
feature_ids_shape = tf.stack([batch_size, (num_neighbours * 4)])
feature_ids = tf.reshape(feature_ids_raw, feature_ids_shape)
feature_ids = tf.expand_dims(feature_ids, 0)
Afterwards they have the previously mentioned shape (batch_size = 1 and num_neighbours = 30 -> [1,1,120]) and I want to use them to slice the activations tensor.
Edit2: I would like the output to be [1,120,4]. (So I would like to gather the entries along the second dimension of the activations tensor in accordance with the IDs stored in my ids tensor.)
You can use:
downsampled_activations = tf.gather(activations, tf.squeeze(ids), axis=1)
downsampled_activations.shape # [1,120,4]
In most cases, tf.gather is used with 1-D indices, and that is what you need here: instead of 3-D indices with shape (1,1,120), 1-D indices with shape (120,) are sufficient. tf.gather looks along the given axis (= 1) and returns the slice at each index provided by the indices tensor.
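As a quick check of the shapes involved, here is a minimal sketch, assuming TensorFlow 2.x with eager execution and using made-up tensors with the shapes from the question (the values are placeholders, not the asker's data). It also shows a batch_dims variant, which is only a suggested alternative for batch_size > 1, where tf.squeeze(ids) would no longer collapse to 1-D.

```python
import tensorflow as tf

# Placeholder data with the shapes from the question (values are made up).
activations = tf.reshape(tf.range(240 * 4), [1, 240, 4])   # [batch, 240, 4]
ids = tf.reshape(tf.range(0, 240, 2), [1, 1, 120])         # [1, batch, 120]

# batch_size == 1: squeeze ids down to shape (120,) and gather along axis 1.
out = tf.gather(activations, tf.squeeze(ids), axis=1)
print(out.shape)  # (1, 120, 4)

# Assumed alternative for larger batches: drop only the leading axis so ids is
# [batch, 120], then batch_dims=1 pairs each row of ids with its own batch entry.
out_batched = tf.gather(activations, tf.squeeze(ids, axis=0), axis=1, batch_dims=1)
print(out_batched.shape)  # (1, 120, 4)
```

With a single batch entry both calls pick the same rows, so either form should work for the shapes in the question.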
tf.gather
Gather slices from params axis axis according to indices.
Granted, the documentation is not the most expressive, and the emphasis should be placed on the slices (since you index slices along the axis and not elements, which is what I suppose you mistakenly took it for).
Let's take a much smaller example:
activations1 = tf.convert_to_tensor([[[1, 2, 3, 4], [11, 22, 33, 44]]])
print(activations1.shape) # [1, 2, 4]
Let's picture this tensor:
            axis 1 ->
            index 0   index 1
axis 2      1         11
  |         2         22
  v         3         33
            4         44
tf.gather(activations1, [0, 0], axis=1)
will return
<tf.Tensor: shape=(1, 2, 4), dtype=int32, numpy=
array([[[1, 2, 3, 4],
[1, 2, 3, 4]]], dtype=int32)>
What tf.gather did was look along axis 1 and pick up index 0 (twice, of course, i.e. [0, 0]). If you were to run tf.gather(activations1, [0, 0, 0, 0, 0], axis=1).shape, you'd get TensorShape([1, 5, 4]).
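To make the "slices, not elements" point concrete, here is a plain-Python sketch of what gathering along axis 1 does for a 3-D nested list. The helper gather_axis1 is a hypothetical name used for illustration, not part of TensorFlow, and it only mimics the axis=1 case with 1-D indices.

```python
def gather_axis1(params, indices):
    """Mimic tf.gather(params, indices, axis=1) for a 3-D nested list and 1-D indices."""
    result = []
    for batch in params:                       # axis 0 is left untouched
        size = len(batch)                      # length of axis 1
        picked = []
        for i in indices:
            if not 0 <= i < size:              # same range rule as tf.gather
                raise IndexError(f"index {i} is not in [0, {size})")
            picked.append(list(batch[i]))      # copy the whole slice along axis 2
        result.append(picked)
    return result


activations1 = [[[1, 2, 3, 4], [11, 22, 33, 44]]]   # shape [1, 2, 4]
print(gather_axis1(activations1, [0, 0]))
# [[[1, 2, 3, 4], [1, 2, 3, 4]]]
print(gather_axis1(activations1, [1, 0, 1]))
# [[[11, 22, 33, 44], [1, 2, 3, 4], [11, 22, 33, 44]]]
```

Each index selects an entire length-4 slice, and the number of indices becomes the new length of axis 1.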
Your Error
Now let's try to trigger the error that you're getting.
tf.gather(activations1, [0, 2], axis=1)
InvalidArgumentError: indices[1] = 2 is not in [0, 2) [Op:GatherV2]
What happened here is that when tf.gather looks along axis 1, there's no item (column, if you will) with index = 2.
I guess this is what the documentation is hinting at in its description of the indices parameter:
The index Tensor. Must be one of the following types: int32, int64. Must be in range [0, params.shape[axis]).
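One way to catch this early is to validate the indices against the axis length before calling tf.gather. This is a sketch assuming TensorFlow 2.x in eager mode; the helper name checked_gather is hypothetical. An explicit check can be useful because, according to the tf.gather documentation, an out-of-bound index raises an error on CPU, while on GPU a 0 is stored in the corresponding output value instead.

```python
import tensorflow as tf


def checked_gather(params, indices, axis):
    """Gather along `axis` after asserting every index lies in [0, params.shape[axis])."""
    indices = tf.convert_to_tensor(indices, dtype=tf.int32)
    limit = tf.shape(params)[axis]             # int32 length of the gathered axis
    tf.debugging.assert_non_negative(indices, message="negative index")
    tf.debugging.assert_less(indices, limit, message="index out of range")
    return tf.gather(params, indices, axis=axis)


activations1 = tf.constant([[[1, 2, 3, 4], [11, 22, 33, 44]]])
print(checked_gather(activations1, [1, 0], axis=1).shape)  # (1, 2, 4)

try:
    checked_gather(activations1, [0, 2], axis=1)           # 2 is not in [0, 2)
except tf.errors.InvalidArgumentError:
    print("rejected: index out of range")
```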
Your (potential) solution
From the dimensions of indices, and those of the expected result in your question, I am not sure whether the above was obvious to you.
tf.gather(activations, indices=[0, 1, 2, 3], axis=2)
or anything with indices within the range [0, activations.shape[2]), i.e. [0, 4), would run without error. Anything else would give you the error that you're getting.
There's a verbatim answer below in case that's your expected result.