[英]Tensorflow gather specific elements in 4d tensor
I have a 4D-tensor Params of dimensions [B, Y, X, N], and want to select a specific slice n ∈ N
from it, such that my resulting tensor is of size [B, Y, X, 1] (or [B, Y, X]).我有一个尺寸为 [B, Y, X, N] 的 4D 张量参数,并且想从中选择一个特定的切片
n ∈ N
,这样我得到的张量的大小为 [B, Y, X, 1] (或 [B, Y, X])。
The specific slice should be the one containing the highest numbers on average;特定切片应该是平均包含最高数字的切片; I obtain the indices like so:
我像这样获得指数:
indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1)
(shape [B]) indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1)
(形状 [B])
I tried different solutions using gather
or gather_nd
, but couldn't get it to work.我使用
gather
或gather_nd
尝试了不同的解决方案,但无法使其正常工作。 There are multiple posts very similar to this, yet I wasn't able to apply one of the solutions presented there.有多个帖子与此非常相似,但我无法应用那里提供的解决方案之一。
I'm running Tensorflow 1.3 so the fancy new axis-parameter for gather
is available.我正在运行 Tensorflow 1.3,因此可以使用花哨的新轴参数
gather
。
In the example code below the input is with shape [2,3,4,5]
and the resulting shape is [2,3,4]
.在下面的示例代码中,输入的形状为
[2,3,4,5]
,结果形状为[2,3,4]
。
The primary ideas are:主要思想是:
gather_nd
, so I switched the last two dimensions with tf.transpose
.gather_nd
,所以我换过去的两个维度与tf.transpose
。tf.argmax
( indices
below) to something really usable (see final_idx
below) in tf.gather_nd
.tf.argmax
( indices
下文),以真正可用(见东西final_idx
以下) tf.gather_nd
。 The conversion is done via stacking of three components:[0 0 0 1 1 1]
[0 1 2 0 1 2]
[3 3 3 0 0 0]
So we could go from [3, 0]
to所以我们可以从
[3, 0]
到
[[[0 0 3]
[0 1 3]
[0 2 3]]
[[1 0 0]
[1 1 0]
[1 2 0]]].
Batch,Y,X = 2, 3, 4
tf.reset_default_graph()
data = np.arange(Batch*Y*X*5)
np.random.shuffle(data)
Params = tf.constant(np.reshape(data, [Batch, Y, X, 5]), dtype=tf.int32)
indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1)
indices = tf.cast(tf.reshape(tf.tile(tf.reshape(indices, [-1,1]),
[1,Y]), [-1]), tf.int32)
idx = tf.reshape(tf.range(batch_size), [-1,1])
idx = tf.reshape(tf.tile(idx, [1, y]), [-1])
inc = tf.reshape(tf.tile(tf.range(Y), [Batch]), [-1])
final_idx = tf.reshape(tf.stack([idx, inc, indices], 1), [Batch, Y, -1])
transposed = tf.transpose(Params, [0, 1, 3, 2])
slice = tf.gather_nd(transposed, final_idx)
with tf.Session() as sess:
print sess.run(Params)
print sess.run(idx)
print sess.run(inc)
print sess.run(indices)
print sess.run(final_idx)
print sess.run(slice)
[[[[ 22 38 68 49 119]
[ 47 74 111 117 90]
[ 14 32 31 12 75]
[ 93 34 57 3 56]]
[[ 69 21 4 94 39]
[ 83 96 62 102 80]
[ 55 113 48 98 29]
[107 81 67 76 28]]
[[ 53 51 77 66 63]
[ 92 115 118 116 13]
[ 43 78 15 1 0]
[ 99 50 27 60 73]]]
[[[ 97 88 91 64 86]
[ 72 110 26 87 33]
[ 70 30 41 114 5]
[ 95 82 46 16 61]]
[[109 71 45 8 40]
[101 9 23 59 10]
[ 37 65 44 11 19]
[ 42 104 106 105 18]]
[[112 58 7 17 89]
[ 25 79 103 85 20]
[ 35 6 108 100 36]
[ 24 52 2 54 84]]]]
[0 0 0 1 1 1]
[0 1 2 0 1 2]
[3 3 3 0 0 0]
[[[0 0 3]
[0 1 3]
[0 2 3]]
[[1 0 0]
[1 1 0]
[1 2 0]]]
[[[ 49 117 12 3]
[ 94 102 98 76]
[ 66 116 1 60]]
[[ 97 72 70 95]
[109 101 37 42]
[112 25 35 24]]]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.