I am trying to gather slices of a tensor along the last dimension, to implement a partial connection between layers. Because the output tensor's shape is [batch_size, h, w, depth], I want to select slices based on the last dimension, such as
# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]
However, tf.gather(L, [0, 2, 3, 8]) seems to only work for the first dimension (right?). Can anyone tell me how to do it?
As of TensorFlow 1.3, tf.gather has an axis parameter, so the various workarounds here are no longer necessary.
https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather https://github.com/tensorflow/tensorflow/issues/11223
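A minimal sketch of the axis-based call, using the shape and indices from the question as sample values (the tensor contents are made up for illustration). Only the static shape is inspected here, so this should run under TensorFlow 1.3+ as well as 2.x; evaluating the values needs a session in 1.x or .numpy() in 2.x.

```python
import numpy as np
import tensorflow as tf

# Sample tensor of shape [batch_size, h, w, depth] = [2, 2, 2, 10] (illustrative values).
L = tf.constant(np.arange(2 * 2 * 2 * 10).reshape([2, 2, 2, 10]))

# Gather the selected slices along the last dimension.
partL = tf.gather(L, [0, 2, 3, 8], axis=-1)

print(partL.shape)  # static shape: (2, 2, 2, 4)
```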
There's a tracking bug to support this use-case here: https://github.com/tensorflow/tensorflow/issues/206
For now you can:
1. Transpose your matrix so that the dimension to gather is first (transpose is expensive).
2. Reshape your tensor into 1-D (reshape is cheap), turn your gather column indices into a list of individual element indices using linear indexing, then reshape back.
3. Use gather_nd. You will still need to turn your column indices into a list of individual element indices.
With gather_nd you can now do this as follows:
cat_idx = tf.concat([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=0)
result = tf.gather_nd(matrix, cat_idx)
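For comparison, the transpose workaround (the first item in the list above) might look like the sketch below. The tensor L and the indices are sample values borrowed from the question, not part of the original answer. Only static shapes are inspected, so it should behave the same in TensorFlow 1.x graph mode and 2.x eager mode.

```python
import numpy as np
import tensorflow as tf

# Illustrative input of shape [2, 2, 2, 10].
L = tf.constant(np.arange(80).reshape([2, 2, 2, 10]))
indices = [0, 2, 3, 8]

# Move the last axis to the front, gather along axis 0, then move it back.
front = tf.transpose(L, perm=[3, 0, 1, 2])        # shape [10, 2, 2, 2]
picked = tf.gather(front, indices)                # shape [4, 2, 2, 2]
partL = tf.transpose(picked, perm=[1, 2, 3, 0])   # shape [2, 2, 2, 4]

print(partL.shape)
```

The two transposes are the cost the answer warns about; the reshape-based and gather_nd-based options avoid them.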
Also, as reported by user Nova in a thread referenced in @Yaroslav Bulatov's answer:
x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
y = tf.gather(tf.reshape(x, [-1]),  # flatten input
              idx_flattened)  # use flattened indices
with tf.Session(''):
    print(y.eval())  # [2 4 9]
The gist is to flatten the tensor and use strided 1-D addressing with tf.gather(...).
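The index arithmetic behind this can be checked without TensorFlow. The plain-Python sketch below reuses the same x and idx values and is only meant to illustrate the row-major formula row * n_cols + col; the helper names (n_cols, flat, flat_idx, picked) are made up for this example.

```python
# Plain-Python illustration of the flattened (row-major) addressing.
x = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
idx = [1, 0, 2]  # column to pick from each row

n_cols = len(x[0])
flat = [value for row in x for value in row]  # row-major flattening

# Element (row, col) of x sits at position row * n_cols + col in flat.
flat_idx = [row * n_cols + col for row, col in enumerate(idx)]
picked = [flat[i] for i in flat_idx]

print(flat_idx)  # [1, 3, 8]
print(picked)    # [2, 4, 9]
```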
Yet another solution, using tf.unstack(...), tf.gather(...) and tf.stack(...):
Code:
import tensorflow as tf
import numpy as np
shape = [2, 2, 2, 10]
L = np.arange(np.prod(shape))
L = np.reshape(L, shape)
indices = [0, 2, 3, 8]
axis = -1 # last dimension
def gather_axis(params, indices, axis=0):
    return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)
print(L)
with tf.Session() as sess:
    partL = sess.run(gather_axis(L, indices, axis))
print(partL)
Result:
L =
[[[[ 0  1  2  3  4  5  6  7  8  9]
   [10 11 12 13 14 15 16 17 18 19]]
  [[20 21 22 23 24 25 26 27 28 29]
   [30 31 32 33 34 35 36 37 38 39]]]
 [[[40 41 42 43 44 45 46 47 48 49]
   [50 51 52 53 54 55 56 57 58 59]]
  [[60 61 62 63 64 65 66 67 68 69]
   [70 71 72 73 74 75 76 77 78 79]]]]
partL =
[[[[ 0  2  3  8]
   [10 12 13 18]]
  [[20 22 23 28]
   [30 32 33 38]]]
 [[[40 42 43 48]
   [50 52 53 58]]
  [[60 62 63 68]
   [70 72 73 78]]]]
A correct version of @Andrei's answer would read:
cat_idx = tf.stack([tf.range(0, tf.shape(matrix)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(matrix, cat_idx)
You can try it this way. For instance (in most NLP cases, at least), the parameter is of shape [batch_size, depth] and the indices are [i, j, k, n, m], a list whose length is batch_size. Then gather_nd can be helpful.
parameters = tf.constant([
    [11, 12, 13],
    [21, 22, 23],
    [31, 32, 33],
    [41, 42, 43]])
targets = tf.constant([2, 1, 0, 1])
batch_nums = tf.range(0, limit=parameters.get_shape().as_list()[0])
indices = tf.stack((batch_nums, targets), axis=1) # the axis is the dimension number
items = tf.gather_nd(parameters, indices)
# which is what we want: [13, 22, 31, 42]
This snippet first selects the row along the first dimension using batch_nums and then fetches the item within that row using the target index.
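To make the pairing explicit, here is a plain-Python sketch (no TensorFlow) of the (row, column) pairs that tf.stack builds and of the lookup that gather_nd then performs for a 2-D input. It reuses the values from the snippet above; the list-based names are for illustration only.

```python
# Plain-Python mirror of the gather_nd example above.
parameters = [[11, 12, 13],
              [21, 22, 23],
              [31, 32, 33],
              [41, 42, 43]]
targets = [2, 1, 0, 1]

# Equivalent of tf.stack((batch_nums, targets), axis=1): one [row, col] pair per row.
indices = [[row, col] for row, col in enumerate(targets)]

# Equivalent of tf.gather_nd for a 2-D input: look up each [row, col] pair.
items = [parameters[row][col] for row, col in indices]

print(indices)  # [[0, 2], [1, 1], [2, 0], [3, 1]]
print(items)    # [13, 22, 31, 42]
```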
Implementing option 2 from @Yaroslav Bulatov's answer:
# Your indices
indices = [0, 2, 3, 8]
# Remember for final reshaping
n_indices = tf.shape(indices)[0]
flattened_L = tf.reshape(L, [-1])
# Walk strided over the flattened array
offset = tf.expand_dims(tf.range(0, tf.reduce_prod(tf.shape(L)), tf.shape(L)[-1]), 1)
flattened_indices = tf.reshape(tf.reshape(indices, [-1]) + offset, [-1])
selected_rows = tf.gather(flattened_L, flattened_indices)
# Final reshape (tf.concat takes the list of values first, then the axis)
partL = tf.reshape(selected_rows, tf.concat([tf.shape(L)[:-1], [n_indices]], axis=0))
Credit to "How to select rows from a 3-D Tensor in TensorFlow?"
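The same strided addressing can be sanity-checked in NumPy, where fancy indexing along the last axis gives a reference result. This is only a sketch mirroring the steps above with NumPy calls; the sample shape is taken from the question and the values are made up.

```python
import numpy as np

# Sample data with the shape used in the question: [2, 2, 2, 10].
L = np.arange(2 * 2 * 2 * 10).reshape(2, 2, 2, 10)
indices = np.array([0, 2, 3, 8])

# Start position of every length-10 row in the flattened array, as a column vector.
offset = np.arange(0, L.size, L.shape[-1])[:, None]

# Broadcasting adds the wanted columns to every row start, then flattens.
flat_idx = (indices + offset).reshape(-1)

# Gather from the flattened array and restore the leading dimensions.
partL = L.reshape(-1)[flat_idx].reshape(L.shape[:-1] + (len(indices),))

# Reference result using NumPy fancy indexing on the last axis.
assert np.array_equal(partL, L[..., indices])
print(partL.shape)  # (2, 2, 2, 4)
```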
In older TensorFlow versions, a Tensor has no shape attribute, only a get_shape() method. The version below runs under Python 2.7:
import tensorflow as tf
import numpy as np
x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.get_shape()[0]) * x.get_shape()[1] + idx
y = tf.gather(tf.reshape(x, [-1]),  # flatten input
              idx_flattened)  # use flattened indices
with tf.Session(''):
    print(y.eval())  # [2 4 9]