
Slicing a tensor with a tensor of indices and tf.gather

I am trying to slice a tensor with an indices tensor. For this purpose I am trying to use tf.gather. However, I am having a hard time understanding the documentation and can't get it to work as I would expect:

I have two tensors: an activations tensor with shape [1,240,4] and an ids tensor with shape [1,1,120]. I want to slice the second dimension of the activations tensor with the indices provided in the third dimension of the ids tensor:

downsampled_activations = tf.gather(activations, ids, axis=1)

I have given it the axis=1 option since that is the axis in the activations tensor I want to slice.

However, this does not render the expected result and only gives me the following error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,0,1] = 1 is not in [0, 1)

I have tried various combinations of the axis and batch_dims options, but to no avail so far, and the documentation doesn't really help me. An explanation of the parameters in more detail, or applied to the example above, would be very helpful!

Edit: The IDs are precomputed before runtime and come in through an input pipeline as such:

features = tf.io.parse_single_example(
    serialized_example,
    features={'featureIDs': tf.io.FixedLenFeature([], tf.string)})

They are then reshaped into the previous format:

feature_ids_raw = tf.io.decode_raw(features['featureIDs'], tf.int32)
feature_ids_shape = tf.stack([batch_size, (num_neighbours * 4)])
feature_ids = tf.reshape(feature_ids_raw, feature_ids_shape)
feature_ids = tf.expand_dims(feature_ids, 0)

Afterwards they have the previously mentioned shape ( batch_size = 1 and num_neighbours = 30 -> [1,1,120] ) and I want to use them to slice the activations tensor.

Edit2: I would like the output to be [1,120,4] . (So I would like to gather the entries along the second dimension of the activations tensor in accordance with the IDs stored in my ids tensor.)

You can use:

downsampled_activations = tf.gather(activations, tf.squeeze(ids), axis=1)
downsampled_activations.shape  # [1, 120, 4]

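tf.gather accepts indices of any rank, and the shape of the indices becomes part of the result: the output shape is params.shape[:axis] + indices.shape + params.shape[axis + 1:]. With ids of shape (1, 1, 120) and axis=1, the result would therefore have shape (1, 1, 1, 120, 4). For the (1, 120, 4) output you want, 1-D indices of shape (120,) are enough, which is what tf.squeeze(ids) produces here. tf.gather then walks along axis 1 and returns the slice at each index in the indices tensor.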
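The shape rule can be checked without running TensorFlow at all. The following is a minimal plain-Python sketch; the helper name gather_output_shape is made up for illustration and is not a TensorFlow function. It assumes the default batch_dims=0.

```python
def gather_output_shape(params_shape, indices_shape, axis):
    """Shape of tf.gather(params, indices, axis=axis) with the default batch_dims=0.

    Hypothetical helper for illustration only: the indices shape replaces
    the gathered axis of params.
    """
    params_shape = tuple(params_shape)
    indices_shape = tuple(indices_shape)
    return params_shape[:axis] + indices_shape + params_shape[axis + 1:]


# 3-D ids as in the question: the extra dimensions end up in the result.
print(gather_output_shape((1, 240, 4), (1, 1, 120), axis=1))  # (1, 1, 1, 120, 4)

# 1-D ids after tf.squeeze: the desired shape.
print(gather_output_shape((1, 240, 4), (120,), axis=1))  # (1, 120, 4)
```

This only computes shapes; it says nothing about whether the index values are in range.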

The tf.gather documentation says: "Gather slices from params axis axis according to indices." (The second "axis" is the name of the parameter.)

Granted, the documentation is not the most expressive. The emphasis should be on the word slices: tf.gather picks whole slices along the chosen axis, not individual elements, which is what I suspect you took it to mean.

Let's take a much smaller example:

activations_small = tf.convert_to_tensor([[[1, 2, 3, 4], [11, 22, 33, 44]]])
print(activations_small.shape)  # (1, 2, 4)

Let's picture this tensor:

    XX 4  XX 44 XX XX
  XX  3 XX  33 X  XX
XXX 2 XX   22XX  XX
X-----X-----+X  XX
|  1  |  11 | XX
+-----+-----+X

tf.gather(activations_small, [0, 0], axis=1) will return

<tf.Tensor: shape=(1, 2, 4), dtype=int32, numpy=
array([[[1, 2, 3, 4],
        [1, 2, 3, 4]]], dtype=int32)>

What tf.gather did was look along axis 1 and pick the slice at index 0, twice, because the indices were [0, 0]. If you were to run tf.gather(activations_small, [0, 0, 0, 0, 0], axis=1).shape, you'd get TensorShape([1, 5, 4]).
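To see why the result consists of whole slices rather than single elements, here is a minimal plain-Python sketch that mimics gathering along axis 1 on a nested list. The helper gather_axis1 is hypothetical and only meant as an illustration of the behaviour; it is not how TensorFlow implements the op.

```python
def gather_axis1(params, indices):
    """Plain-Python stand-in for tf.gather(params, indices, axis=1) on a rank-3 nested list."""
    out = []
    for batch in params:  # axis 0 is left untouched
        size = len(batch)  # number of slices available along axis 1
        for i in indices:
            if not 0 <= i < size:
                raise IndexError(f"index {i} is not in [0, {size})")
        # Each picked item is a whole slice (a list of 4 values), not a single element.
        out.append([list(batch[i]) for i in indices])
    return out


activations_small = [[[1, 2, 3, 4], [11, 22, 33, 44]]]
print(gather_axis1(activations_small, [0, 0]))  # [[[1, 2, 3, 4], [1, 2, 3, 4]]]
print(gather_axis1(activations_small, [1, 0]))  # [[[11, 22, 33, 44], [1, 2, 3, 4]]]
```

The length of the result along axis 1 always equals the number of indices, which is why five indices give a shape of [1, 5, 4].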

Your error

Now let's try to trigger the error that you're getting.

tf.gather(activations_small, [0, 2], axis=1)

InvalidArgumentError: indices[1] = 2 is not in [0, 2) [Op:GatherV2]

What happened here is that, along axis 1, activations_small has only two slices (columns, if you will), at indices 0 and 1, so there is no slice with index 2.

I guess this is what the documentation is hinting at in its description of the indices parameter:

indices: The index Tensor. Must be one of the following types: int32, int64. Must be in range [0, params.shape[axis]).
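If it helps, the range requirement can be checked explicitly before gathering. The sketch below assumes TensorFlow 2.x in eager mode; checked_gather is a hypothetical wrapper written for this example, while tf.debugging.assert_non_negative and tf.debugging.assert_less are existing TensorFlow functions. The tf.gather documentation also notes that on GPU an out-of-range index does not raise but stores a 0 in the output, so an explicit check like this may be worth having.

```python
import tensorflow as tf


def checked_gather(params, indices, axis):
    """Hypothetical wrapper: verify indices are in [0, params.shape[axis]) and then gather."""
    indices = tf.convert_to_tensor(indices)
    limit = tf.cast(tf.shape(params)[axis], indices.dtype)
    tf.debugging.assert_non_negative(indices, message="negative index")
    tf.debugging.assert_less(indices, limit, message="index out of range for this axis")
    return tf.gather(params, indices, axis=axis)


activations_small = tf.constant([[[1, 2, 3, 4], [11, 22, 33, 44]]])

ok = checked_gather(activations_small, [1, 0], axis=1)
print(ok.shape)  # (1, 2, 4)

try:
    checked_gather(activations_small, [0, 2], axis=1)  # 2 is not in [0, 2)
except tf.errors.InvalidArgumentError:
    print("index 2 rejected before gathering")
```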

Your (potential) solution

From the dimensions of indices , and that of the expected result from your question, I am not sure if the above was very obvious to you.

tf.gather(activations, indices=[0, 1, 2, 3], axis=2), or any call whose indices lie in [0, activations.shape[2]), i.e. [0, 4), would work. Likewise, with axis=1 every index must lie in [0, activations.shape[1]), i.e. [0, 240). Any index outside the range for the chosen axis gives the error that you're getting.
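For the shapes in the question, a sketch along the following lines should give the desired [1, 120, 4] output. It assumes TensorFlow 2.x in eager mode and uses random stand-in data, since the real activations and ids are not shown. The second variant uses the batch_dims argument the question mentions, which may be the better fit once batch_size is larger than 1.

```python
import tensorflow as tf

# Shapes taken from the question; the values are random stand-ins.
batch_size, num_points, num_features, num_ids = 1, 240, 4, 120

activations = tf.random.normal([batch_size, num_points, num_features])
# ids laid out as [1, batch_size, 120], with every value in [0, 240).
ids = tf.random.uniform([1, batch_size, num_ids], maxval=num_points, dtype=tf.int32)

# Variant 1: flatten ids to 1-D. Only meaningful when batch_size == 1.
flat = tf.gather(activations, tf.reshape(ids, [-1]), axis=1)

# Variant 2: drop only the leading axis and let batch_dims=1 pair each
# batch entry of activations with its own row of ids.
per_batch = tf.gather(activations, tf.squeeze(ids, axis=0), axis=1, batch_dims=1)

print(flat.shape, per_batch.shape)  # (1, 120, 4) (1, 120, 4)
```

With batch_dims=1 the output shape is params.shape[:axis] + indices.shape[1:] + params.shape[axis + 1:], so the leading batch dimension of the indices is matched against the batch dimension of the activations instead of being added to the result.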

There's a verbatim answer below in case that's your expected result.


 