Batched 4D tensor TensorFlow indexing

Given

  • batch_images : 4D tensor of shape (B, H, W, C)
  • x : 3D tensor of shape (B, H, W)
  • y : 3D tensor of shape (B, H, W)

Goal

How can I index into batch_images using the x and y coordinates to obtain a 4D tensor of shape (B, H, W, C)? That is, for each batch element and each pair (x, y), I want to obtain a tensor of shape (C,).

In NumPy, this could be achieved with input_img[np.arange(B)[:,None,None], y, x], for example, but I can't seem to make it work in TensorFlow.
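
For reference, here is a minimal NumPy sketch of that indexing. The array sizes and the random data are assumptions chosen only for illustration:

```python
import numpy as np

B, H, W, C = 3, 4, 5, 6  # illustrative sizes, not taken from a real dataset
rng = np.random.default_rng(0)
input_img = rng.uniform(size=(B, H, W, C))
x = rng.integers(0, W, size=(B, H, W))  # column coordinates
y = rng.integers(0, H, size=(B, H, W))  # row coordinates

# np.arange(B)[:, None, None] has shape (B, 1, 1) and broadcasts against
# y and x, so every (y, x) pair is looked up in its own batch element.
out = input_img[np.arange(B)[:, None, None], y, x]
print(out.shape)  # (3, 4, 5, 6)
```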

My attempt so far

def get_pixel_value(img, x, y):
    """
    Utility function to get pixel value for 
    coordinate vectors x and y from a  4D tensor image.
    """
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    C = tf.shape(img)[3]

    # flatten image
    img_flat = tf.reshape(img, [-1, C])

    # flatten idx
    idx_flat = (x*W) + y

    return tf.gather(img_flat, idx_flat)

which returns an incorrect tensor of shape (B, H, W).

It should be possible to do it by flattening the tensor as you've done, but the batch dimension has to be taken into account in the index calculation. To do this, you need an additional batch index tensor with the same shape as x and y, in which every entry holds the index of the batch element it belongs to. This is basically the np.arange(B) from your NumPy example, which is missing from your TensorFlow code.
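
The index arithmetic for the flattened approach can be sketched in NumPy as follows. The sizes and random data are illustrative assumptions, and the same formula would carry over to tf.reshape and tf.gather. Note that for a row-major (B, H, W, C) layout the row coordinate y is the one multiplied by W:

```python
import numpy as np

B, H, W, C = 3, 4, 5, 6  # illustrative sizes
rng = np.random.default_rng(0)
img = rng.uniform(size=(B, H, W, C))
x = rng.integers(0, W, size=(B, H, W))  # column coordinates
y = rng.integers(0, H, size=(B, H, W))  # row coordinates

# Batch index with the same shape as x and y: entry [b, i, j] equals b.
b = np.broadcast_to(np.arange(B)[:, None, None], (B, H, W))

# Row-major flat position of pixel (b, y, x) once the image is
# reshaped to (B*H*W, C).
idx_flat = b * (H * W) + y * W + x

img_flat = img.reshape(-1, C)
out = img_flat[idx_flat]
print(out.shape)  # (3, 4, 5, 6)
```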

You can also simplify things a bit by using tf.gather_nd, which does the index calculations for you.

Here's an example:

import numpy as np
import tensorflow as tf

# Example tensors
M = np.random.uniform(size=(3, 4, 5, 6))
x = np.random.randint(0, 5, size=(3, 4, 5))
y = np.random.randint(0, 4, size=(3, 4, 5))

def get_pixel_value(img, x, y):
    """
    Utility function that composes a new image, with pixels taken
    from the coordinates given in x and y.
    The shapes of x and y have to match.
    The batch order is preserved.
    """

    # We assume that x and y have the same shape.
    shape = tf.shape(x)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]

    # Create a tensor that indexes into the same batch.
    # This is needed for gather_nd to work.
    batch_idx = tf.range(0, batch_size)
    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
    b = tf.tile(batch_idx, (1, height, width))

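    # Stack along a new last axis so that indices[b, i, j] = (batch, y, x).
    # tf.stack is the current name of the former tf.pack (renamed in TF 1.0).
    indices = tf.stack([b, y, x], 3)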
    return tf.gather_nd(img, indices)

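# tf.Session is the TensorFlow 1.x graph-mode API
# (available as tf.compat.v1.Session in TensorFlow 2.x).
s = tf.Session()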
print(s.run(get_pixel_value(M, x, y)).shape)
# Should print (3, 4, 5, 6).
# We've composed a new image of the same size from randomly picked x and y
# coordinates of each original image.
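
Newer TensorFlow releases let tf.gather_nd take a batch_dims argument, which should make the explicit batch index tensor unnecessary. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution; the function name get_pixel_value_batched and the test data are hypothetical and not part of the original answer:

```python
import numpy as np
import tensorflow as tf

def get_pixel_value_batched(img, x, y):
    """Gather img[b, y, x, :] for every batch element b (hypothetical helper)."""
    # indices has shape (B, H, W, 2); each entry is a (y, x) pair.
    indices = tf.stack([y, x], axis=-1)
    # batch_dims=1 treats the leading axis of img and indices as a shared
    # batch axis, so no explicit batch index is required.
    return tf.gather_nd(img, indices, batch_dims=1)

rng = np.random.default_rng(0)
img = rng.uniform(size=(3, 4, 5, 6)).astype(np.float32)
x = rng.integers(0, 5, size=(3, 4, 5)).astype(np.int32)  # column coordinates
y = rng.integers(0, 4, size=(3, 4, 5)).astype(np.int32)  # row coordinates

out = get_pixel_value_batched(tf.constant(img), tf.constant(x), tf.constant(y))
print(out.shape)  # (3, 4, 5, 6)
```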
