Batched 4D tensor TensorFlow indexing
Given
batch_images: 4D tensor of shape (B, H, W, C)
x: 3D tensor of shape (B, H, W)
y: 3D tensor of shape (B, H, W)
Goal
How can I index into batch_images using the x and y coordinates to obtain a 4D tensor of shape (B, H, W, C)? That is, for each batch element and each pair (x, y), I want to obtain a tensor of shape (C,).
In NumPy, this would be achieved using input_img[np.arange(B)[:,None,None], y, x], for example, but I can't seem to make it work in TensorFlow.
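For reference, the NumPy indexing can be checked on small random inputs. This is only a minimal sketch: the sizes, the generator `rng`, and the loop-built array `expected` are illustrative choices, not part of the original post.

```python
import numpy as np

B, H, W, C = 3, 4, 5, 6
rng = np.random.default_rng(0)

input_img = rng.uniform(size=(B, H, W, C))
x = rng.integers(0, W, size=(B, H, W))  # column indices, in [0, W)
y = rng.integers(0, H, size=(B, H, W))  # row indices, in [0, H)

# The (B, 1, 1) batch index broadcasts against y and x, so each output
# position (b, i, j) reads input_img[b, y[b, i, j], x[b, i, j], :].
out = input_img[np.arange(B)[:, None, None], y, x]
print(out.shape)  # (3, 4, 5, 6)

# Same result computed with explicit loops, for comparison.
expected = np.empty_like(input_img)
for b in range(B):
    for i in range(H):
        for j in range(W):
            expected[b, i, j] = input_img[b, y[b, i, j], x[b, i, j]]
assert np.array_equal(out, expected)
```

The loop version makes explicit what the broadcasted batch index does: without it, there is nothing tying each (x, y) pair to its own image in the batch.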
My attempt so far
def get_pixel_value(img, x, y):
    """
    Utility function to get pixel value for
    coordinate vectors x and y from a 4D tensor image.
    """
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    C = tf.shape(img)[3]
    # flatten image
    img_flat = tf.reshape(img, [-1, C])
    # flatten idx
    idx_flat = (x*W) + y
    return tf.gather(img_flat, idx_flat)
which is returning an incorrect tensor of shape (B, H, W).
It should be possible to do it by flattening the tensor as you've done, but the batch dimension has to be taken into account in the index calculation. To do this, you'll have to make an additional dummy batch index tensor with the same shape as x and y that always contains the index of the current batch. This is basically the np.arange(B) from your NumPy example, which is missing from your TensorFlow code.
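The flattening approach with a batch offset could look roughly like the sketch below. It assumes TensorFlow 2.x with eager execution, and the helper name `get_pixel_value_flat` is hypothetical. Note that for an image laid out as (B, H, W, C), the row-major offset of pixel (b, y, x) is b*H*W + y*W + x, so y is the coordinate that gets multiplied by W, unlike in the attempt from the question.

```python
import numpy as np
import tensorflow as tf


def get_pixel_value_flat(img, x, y):
    """Gather img[b, y, x, :] through a flattened index (hypothetical helper)."""
    shape = tf.shape(img)
    B, H, W, C = shape[0], shape[1], shape[2], shape[3]
    # Use one integer dtype for all index arithmetic.
    x = tf.cast(x, tf.int32)
    y = tf.cast(y, tf.int32)
    # Batch index of shape (B, 1, 1); it broadcasts against x and y.
    b = tf.reshape(tf.range(B), (-1, 1, 1))
    # Row-major offset of pixel (b, y, x) in the (B*H*W, C) view of img.
    idx_flat = b * (H * W) + y * W + x
    img_flat = tf.reshape(img, (-1, C))
    return tf.gather(img_flat, idx_flat)


# Illustrative inputs, same sizes as a (3, 4, 5, 6) batch of images.
img = np.random.uniform(size=(3, 4, 5, 6))
x = np.random.randint(0, 5, size=(3, 4, 5))
y = np.random.randint(0, 4, size=(3, 4, 5))

out = get_pixel_value_flat(img, x, y)
reference = img[np.arange(3)[:, None, None], y, x]
print(out.shape)  # (3, 4, 5, 6)
assert np.allclose(out.numpy(), reference)
```

Here the batch index is left at shape (B, 1, 1) and relies on broadcasting instead of being tiled to the full shape of x and y; either form gives the same offsets.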
You can also simplify things a bit by using tf.gather_nd, which does the index calculations for you.
Here's an example:
import numpy as np
import tensorflow as tf

# Example tensors
M = np.random.uniform(size=(3, 4, 5, 6))
x = np.random.randint(0, 5, size=(3, 4, 5))
y = np.random.randint(0, 4, size=(3, 4, 5))

def get_pixel_value(img, x, y):
    """
    Utility function that composes a new image, with pixels taken
    from the coordinates given in x and y.
    The shapes of x and y have to match.
    The batch order is preserved.
    """
    # All index components must share one integer dtype.
    x = tf.cast(x, tf.int32)
    y = tf.cast(y, tf.int32)
    # We assume that x and y have the same shape.
    shape = tf.shape(x)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]
    # Create a tensor that indexes into the same batch.
    # This is needed for gather_nd to work.
    batch_idx = tf.range(0, batch_size)
    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
    b = tf.tile(batch_idx, (1, height, width))
    # tf.stack replaces tf.pack, which was removed in TensorFlow 1.0.
    indices = tf.stack([b, y, x], 3)
    return tf.gather_nd(img, indices)

# TensorFlow 2.x executes eagerly, so no tf.Session is needed.
print(get_pixel_value(M, x, y).shape)
# Should print (3, 4, 5, 6).
# We've composed a new image of the same size from randomly picked x and y
# coordinates of each original image.
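As a possible alternative that is not part of the original answer, TensorFlow 2.x lets tf.gather_nd treat leading axes as shared batch axes through its batch_dims argument, which removes the need to build the batch index by hand. The sketch below assumes TensorFlow 2.x with eager execution; the variable names and sizes are illustrative.

```python
import numpy as np
import tensorflow as tf

# Illustrative inputs: a batch of 3 images of size 4 x 5 with 6 channels.
img = np.random.uniform(size=(3, 4, 5, 6))
x = np.random.randint(0, 5, size=(3, 4, 5)).astype(np.int32)
y = np.random.randint(0, 4, size=(3, 4, 5)).astype(np.int32)

# With batch_dims=1, the first axis of img and of indices is treated as a
# shared batch axis, so only the (y, x) pair has to be supplied per pixel.
indices = tf.stack([y, x], axis=-1)             # shape (B, H, W, 2)
out = tf.gather_nd(img, indices, batch_dims=1)  # shape (B, H, W, C)

# Compare against the NumPy indexing from the question.
reference = img[np.arange(3)[:, None, None], y, x]
print(out.shape)  # (3, 4, 5, 6)
assert np.allclose(out.numpy(), reference)
```

Both variants read the same pixels; this one simply delegates the per-batch bookkeeping to gather_nd.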