[英]Tensorflow: extract sequential patches from a complex tensor of arbitrary length
我試圖弄清楚如何從長度可變的復值張量中提取連續的補丁。 提取是作為tf.data
管道的一部分執行的。
如果張量不復雜,我會像在這個答案中一樣使用tf.image.extract_image_patches
。
但是,該函數不適用於復雜張量。 我嘗試了以下技術,但它失敗了,因為張量的長度未知。
def extract_sequential_patches(image):
image_length = tf.shape(image)[0]
num_patches = image_length // (128 // 4)
patches = []
for i in range(num_patches):
start = i * 128
end = start + 128
patches.append(image[start:end, ...])
return tf.stack(patches)
但是我收到錯誤:
InaccessibleTensorError: The tensor 'Tensor("strided_slice:0", shape=(None, 512, 2), dtype=complex64)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_2100, id=140313967335120)
我嘗試過用@tf.function
自由裝飾
我認為您需要調整索引的計算以確保它們不會越界,但撇開這些細節不談,您的代碼幾乎是tf.function
期望的,除了使用 Python 列表; 您需要改用 TensorArray。
這樣的事情應該可以工作(索引計算可能不完全正確):
@tf.function
def extract_sequential_patches(image, size, stride):
image_length = tf.shape(image)[0]
num_patches = (image_length - size) // stride + 1
patches = tf.TensorArray(image.dtype, size=num_patches)
for i in range(num_patches):
start = i * stride
end = start + size
patches = patches.write(i, image[start:end, ...])
return patches.stack()
您可以在簽名參考文檔中找到有關 Python 列表當前為何不起作用的更多詳細信息。
也就是說,如果優化了 extract_image_patches 的內核,使用 real/imag 技巧可能會更快。 我建議測試這兩種方法。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.