
Tensorflow: extract sequential patches from a complex tensor of arbitrary length

I'm trying to figure out how to extract sequential patches from a complex-valued tensor of variable length. The extraction is performed as part of a tf.data pipeline.

If the tensor were not complex, I'd use tf.image.extract_image_patches as in this answer.

However, that function does not work with complex tensors. I have tried the following technique, but it fails because the length of the tensor is unknown.

def extract_sequential_patches(image):
    image_length = tf.shape(image)[0]
    num_patches = image_length // (128 // 4)
    patches = []
    for i in range(num_patches):
        start = i * 128
        end = start + 128
        patches.append(image[start:end, ...])
    return tf.stack(patches)

However, I get the error:

InaccessibleTensorError: The tensor 'Tensor("strided_slice:0", shape=(None, 512, 2), dtype=complex64)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_2100, id=140313967335120)

I have tried liberal decoration with @tf.function.

I think you'll need to adjust the index calculations to make sure they don't go out of bounds, but leaving that detail aside, your code is almost what tf.function expects, except for the use of a Python list; you need to use a TensorArray instead.

Something like this should work (the index calculations might not be entirely right):

@tf.function
def extract_sequential_patches(image, size, stride):
    image_length = tf.shape(image)[0]
    num_patches = (image_length - size) // stride + 1
    patches = tf.TensorArray(image.dtype, size=num_patches)
    for i in range(num_patches):
        start = i * stride
        end = start + size
        patches = patches.write(i, image[start:end, ...])
    return patches.stack()
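As a quick sanity check (my own example, not part of the original answer), the function can be exercised eagerly on a dummy complex tensor:

```python
import tensorflow as tf

@tf.function
def extract_sequential_patches(image, size, stride):
    image_length = tf.shape(image)[0]
    num_patches = (image_length - size) // stride + 1
    patches = tf.TensorArray(image.dtype, size=num_patches)
    for i in range(num_patches):  # AutoGraph converts this into a tf.while_loop
        start = i * stride
        patches = patches.write(i, image[start:start + size, ...])
    return patches.stack()

# A dummy complex64 tensor of shape [10, 4, 2]
image = tf.complex(tf.ones([10, 4, 2]), tf.zeros([10, 4, 2]))
patches = extract_sequential_patches(image, size=4, stride=2)
print(patches.shape)  # (4, 4, 4, 2): (10 - 4) // 2 + 1 = 4 patches of length 4
```

Because the patch count depends on tf.shape, the loop bound is a tensor, which is exactly the situation where a Python list fails and a TensorArray is required.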

You can find more details on why Python lists don't currently work in the AutoGraph reference docs.

That said, it might be faster to use the real/imag trick if the kernel of extract_image_patches is optimized. I recommend testing both approaches.

Disclaimer: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site or the original source. For any questions, contact: yoyou2525@163.com.

 