簡體   English   中英

Tensorflow:從任意長度的復雜張量中提取連續的補丁

[英]Tensorflow: extract sequential patches from a complex tensor of arbitrary length

我試圖弄清楚如何從長度可變的復值張量中提取連續的補丁。 提取是作為tf.data管道的一部分執行的。

如果張量不復雜,我會像在這個答案中一樣使用tf.image.extract_image_patches

但是,該函數不適用於復雜張量。 我嘗試了以下技術,但它失敗了,因為張量的長度未知。

def extract_sequential_patches(image):
    image_length = tf.shape(image)[0]
    num_patches = image_length // (128 // 4)
    patches = []
    for i in range(num_patches):
        start = i * 128
        end = start + 128
        patches.append(image[start:end, ...])
    return tf.stack(patches)

但是我收到錯誤:

InaccessibleTensorError: The tensor 'Tensor("strided_slice:0", shape=(None, 512, 2), dtype=complex64)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_2100, id=140313967335120)

我嘗試過用@tf.function自由裝飾

我認為您需要調整索引的計算以確保它們不會越界,但撇開這些細節不談,您的代碼幾乎是tf.function期望的,除了使用 Python 列表; 您需要改用 TensorArray。

這樣的事情應該可以工作(索引計算可能不完全正確):

@tf.function
def extract_sequential_patches(image, size, stride):
    image_length = tf.shape(image)[0]
    num_patches = (image_length - size) // stride + 1
    patches = tf.TensorArray(image.dtype, size=num_patches)
    for i in range(num_patches):
        start = i * stride
        end = start + size
        patches = patches.write(i, image[start:end, ...])
    return patches.stack()

您可以在簽名參考文檔中找到有關 Python 列表當前為何不起作用的更多詳細信息。

也就是說,如果優化了 extract_image_patches 的內核,使用 real/imag 技巧可能會更快。 我建議測試這兩種方法。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM