
Tensorflow gather specific elements in 4d tensor

I have a 4D tensor Params of shape [B, Y, X, N] and want to select one specific slice n along the last axis (0 ≤ n < N), so that the resulting tensor has shape [B, Y, X, 1] (or [B, Y, X]).

The slice should be the one containing the highest numbers on average, chosen separately for each batch element. I obtain the indices like so:
indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1) (shape [B])
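To make the shapes in that expression concrete, here is a minimal NumPy sketch of the same reduction. It is not TensorFlow code: the sizes and the random data are assumptions chosen only for illustration. Because every slice holds the same number of elements (Y*X), the slice with the largest sum is also the one with the largest mean.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
B, Y, X, N = 2, 3, 4, 5
rs = np.random.RandomState(0)
params = rs.randint(0, 100, size=(B, Y, X, N))

# Sum over the two spatial axes (Y and X): shape [B, N].
per_slice_sum = params.sum(axis=(1, 2))

# For each batch element, the slice with the largest sum: shape [B].
indices = per_slice_sum.argmax(axis=1)

print(per_slice_sum.shape)
print(indices.shape)
```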

I tried different solutions using gather or gather_nd, but couldn't get them to work. There are multiple posts very similar to this one, yet I wasn't able to apply any of the solutions presented there.

I'm running TensorFlow 1.3, so the new axis parameter for gather is available.
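One likely reason the axis parameter alone is not enough: a gather along an axis applies every index to every batch element, so the output has shape params.shape[:axis] + indices.shape + params.shape[axis+1:]. The sketch below uses NumPy's np.take as a stand-in for that behaviour; the data and the index values [3, 0] are assumptions for illustration.

```python
import numpy as np

B, Y, X, N = 2, 3, 4, 5
params = np.arange(B * Y * X * N).reshape(B, Y, X, N)
indices = np.array([3, 0])  # assumed: one slice index per batch element

# Gathering along the last axis pairs *every* batch element with *every* index.
gathered = np.take(params, indices, axis=3)
print(gathered.shape)  # (2, 3, 4, 2)

# What is wanted instead pairs batch b only with indices[b],
# i.e. the "diagonal" over the first and last axes of `gathered`.
wanted = np.stack([gathered[b, :, :, b] for b in range(B)])
print(wanted.shape)  # (2, 3, 4)
```

Extracting that diagonal works but computes B times more values than needed, which is why the per-element index construction with gather_nd is the more direct route.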

In the example code below, the input has shape [2,3,4,5] and the result has shape [2,3,4].

The primary ideas are:

  • With gather_nd it is easy to select a row rather than a column, so I swapped the last two dimensions with tf.transpose.
  • We need to convert the indices we get from tf.argmax (indices below) into something directly usable by tf.gather_nd (final_idx below). The conversion is done by stacking three components:
    • [0 0 0 1 1 1] (the batch index, idx below)
    • [0 1 2 0 1 2] (the index along Y, inc below)
    • [3 3 3 0 0 0] (the argmax result, each entry repeated Y times)

So we can go from [3, 0] to

 [[[0 0 3]
   [0 1 3]
   [0 2 3]]
  [[1 0 0]
   [1 1 0]
   [1 2 0]]].
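The stacking step can be checked in isolation. The following is a minimal NumPy sketch, not the TensorFlow code itself: np.repeat and np.tile stand in for the reshape/tile/reshape chains, and the variable names batch_idx, row_idx and slice_idx are made up for this illustration.

```python
import numpy as np

B, Y = 2, 3
indices = np.array([3, 0])  # assumed argmax result, one entry per batch element

batch_idx = np.repeat(np.arange(B), Y)  # [0 0 0 1 1 1]
row_idx = np.tile(np.arange(Y), B)      # [0 1 2 0 1 2]
slice_idx = np.repeat(indices, Y)       # [3 3 3 0 0 0]

# One (batch, y, n) triple per row, regrouped as [B, Y, 3].
final_idx = np.stack([batch_idx, row_idx, slice_idx], axis=1).reshape(B, Y, 3)
print(final_idx)
```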
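The full example:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (graph mode, tf.Session)

Batch, Y, X = 2, 3, 4
tf.reset_default_graph()

data = np.arange(Batch*Y*X*5)
np.random.shuffle(data)

Params = tf.constant(np.reshape(data, [Batch, Y, X, 5]), dtype=tf.int32)
# Per-batch index of the slice with the largest sum, shape [Batch]
indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1)
# Repeat each index Y times, e.g. [3, 0] -> [3 3 3 0 0 0]
indices = tf.cast(tf.reshape(tf.tile(tf.reshape(indices, [-1,1]),
                                     [1,Y]), [-1]), tf.int32)

# Batch index, each repeated Y times: [0 0 0 1 1 1]
idx = tf.reshape(tf.range(Batch), [-1,1])
idx = tf.reshape(tf.tile(idx, [1, Y]), [-1])

# Index along Y, repeated for every batch element: [0 1 2 0 1 2]
inc = tf.reshape(tf.tile(tf.range(Y), [Batch]), [-1])
final_idx = tf.reshape(tf.stack([idx, inc, indices], 1), [Batch, Y, -1])
# Move N in front of X so each (batch, y, n) triple selects a row of length X
transposed = tf.transpose(Params, [0, 1, 3, 2])
sliced = tf.gather_nd(transposed, final_idx)

with tf.Session() as sess:
    print(sess.run(Params))
    print(sess.run(idx))
    print(sess.run(inc))
    print(sess.run(indices))
    print(sess.run(final_idx))
    print(sess.run(sliced))

This prints Params, idx, inc, indices, final_idx and the gathered result, in that order: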
[[[[ 22  38  68  49 119]
   [ 47  74 111 117  90]
   [ 14  32  31  12  75]
   [ 93  34  57   3  56]]

  [[ 69  21   4  94  39]
   [ 83  96  62 102  80]
   [ 55 113  48  98  29]
   [107  81  67  76  28]]

  [[ 53  51  77  66  63]
   [ 92 115 118 116  13]
   [ 43  78  15   1   0]
   [ 99  50  27  60  73]]]


 [[[ 97  88  91  64  86]
   [ 72 110  26  87  33]
   [ 70  30  41 114   5]
   [ 95  82  46  16  61]]

  [[109  71  45   8  40]
   [101   9  23  59  10]
   [ 37  65  44  11  19]
   [ 42 104 106 105  18]]

  [[112  58   7  17  89]
   [ 25  79 103  85  20]
   [ 35   6 108 100  36]
   [ 24  52   2  54  84]]]]

[0 0 0 1 1 1]
[0 1 2 0 1 2]
[3 3 3 0 0 0]

[[[0 0 3]
  [0 1 3]
  [0 2 3]]

 [[1 0 0]
  [1 1 0]
  [1 2 0]]]

[[[ 49 117  12   3]
  [ 94 102  98  76]
  [ 66 116   1  60]]

 [[ 97  72  70  95]
  [109 101  37  42]
  [112  25  35  24]]]
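As a sanity check on the approach, the transpose-then-gather route can be compared against direct per-batch indexing. This is a NumPy sketch under the same shapes as the example, not TensorFlow code; NumPy's integer-array indexing plays the role of gather_nd here, and the seed and variable names are assumptions.

```python
import numpy as np

B, Y, X, N = 2, 3, 4, 5
rs = np.random.RandomState(0)
params = rs.permutation(B * Y * X * N).reshape(B, Y, X, N)
indices = params.sum(axis=(1, 2)).argmax(axis=1)  # shape [B]

# Route 1: mirror the answer. Swap the last two axes, then index with
# one (batch, y, n) triple per output row.
transposed = params.transpose(0, 1, 3, 2)         # [B, Y, N, X]
b = np.repeat(np.arange(B), Y).reshape(B, Y)
y = np.tile(np.arange(Y), B).reshape(B, Y)
n = np.repeat(indices, Y).reshape(B, Y)
via_transpose = transposed[b, y, n]               # [B, Y, X]

# Route 2: pick slice indices[b] of batch element b directly.
direct = params[np.arange(B), :, :, indices]      # [B, Y, X]

print(np.array_equal(via_transpose, direct))
```

Both routes should agree for any input, since each output row [b, y, :] is Params[b, y, :, indices[b]] in either case.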
