簡體   English   中英

了解用於從圖像中提取補丁的 tf.extract_image_patches

[英]Understanding tf.extract_image_patches for extracting patches from an image

我在 tensorflow API 中找到了以下方法tf.extract_image_patches ,但我不清楚它的功能。

假設batch_size = 1 ,圖像大小為225x225x3 ,我們想要提取大小為32x32補丁。

這個函數的行為究竟如何? 具體來說,文檔提到輸出張量的維度為[batch, out_rows, out_cols, ksize_rows * ksize_cols * depth] ,但沒有提到out_rowsout_cols是什么。

理想情況下,給定大小為1x225x225x3 (其中 1 是批量大小)的輸入圖像張量,我希望能夠獲得Kx32x32x3作為輸出,其中K是補丁的總數, 32x32x3是每個補丁的尺寸。 tensorflow 中是否有一些東西已經實現了這一點?

以下是該方法的工作原理:

  • ksizes用於決定每個補丁的尺寸,或者換句話說,每個補丁應該包含多少像素。
  • strides表示原始圖像中一個補丁的開始和下一個連續補丁的開始之間的間隙長度。
  • rates是一個數字,本質上意味着我們的補丁應該按照原始圖像中每個以我們補丁結束的連續像素為單位的像素rates跳躍。 (下面的例子有助於說明這一點。)
  • padding要么是“VALID”,這意味着每個補丁都必須完全包含在圖像中,或者是“SAME”,這意味着補丁可以不完整(剩余的像素將用零填充)。

下面是一些帶有輸出的示例代碼,以幫助演示它是如何工作的:

import tensorflow as tf

n = 10
# images is a 1 x 10 x 10 x 1 array that contains the numbers 1 through 100 in order
images = [[[[x * n + y + 1] for y in range(n)] for x in range(n)]]

# We generate four outputs as follows:
# 1. 3x3 patches with stride length 5
# 2. Same as above, but the rate is increased to 2
# 3. 4x4 patches with stride length 7; only one patch should be generated
# 4. Same as above, but with padding set to 'SAME'
with tf.Session() as sess:
  print tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1], rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n'
  print tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1], rates=[1, 2, 2, 1], padding='VALID').eval(), '\n\n'
  print tf.extract_image_patches(images=images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1], rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n'
  print tf.extract_image_patches(images=images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1], rates=[1, 1, 1, 1], padding='SAME').eval()

輸出:

[[[[ 1  2  3 11 12 13 21 22 23]
   [ 6  7  8 16 17 18 26 27 28]]

  [[51 52 53 61 62 63 71 72 73]
   [56 57 58 66 67 68 76 77 78]]]]


[[[[  1   3   5  21  23  25  41  43  45]
   [  6   8  10  26  28  30  46  48  50]]

  [[ 51  53  55  71  73  75  91  93  95]
   [ 56  58  60  76  78  80  96  98 100]]]]


[[[[ 1  2  3  4 11 12 13 14 21 22 23 24 31 32 33 34]]]]


[[[[  1   2   3   4  11  12  13  14  21  22  23  24  31  32  33  34]
   [  8   9  10   0  18  19  20   0  28  29  30   0  38  39  40   0]]

  [[ 71  72  73  74  81  82  83  84  91  92  93  94   0   0   0   0]
   [ 78  79  80   0  88  89  90   0  98  99 100   0   0   0   0   0]]]]

因此,例如,我們的第一個結果如下所示:

 *  *  *  4  5  *  *  *  9 10 
 *  *  * 14 15  *  *  * 19 20 
 *  *  * 24 25  *  *  * 29 30 
31 32 33 34 35 36 37 38 39 40 
41 42 43 44 45 46 47 48 49 50 
 *  *  * 54 55  *  *  * 59 60 
 *  *  * 64 65  *  *  * 69 70 
 *  *  * 74 75  *  *  * 79 80 
81 82 83 84 85 86 87 88 89 90 
91 92 93 94 95 96 97 98 99 100 

如您所見,我們有 2 行 2 列的補丁,即out_rowsout_cols

為了擴展 Neal 的詳細答案,在使用“SAME”時零填充有很多微妙之處,因為extract_image_patches 會盡可能地將圖像中的補丁居中。 根據步幅,頂部和左側可能有填充,也可能沒有,第一個補丁不一定從左上角開始。

例如,擴展前面的例子:

print tf.extract_image_patches(images, [1, 3, 3, 1], [1, n, n, 1], [1, 1, 1, 1], 'SAME').eval()[0]

當步長為 n=1 時,圖像四周用零填充,第一個補丁從填充開始。 其他步幅僅在右側和底部填充圖像,或者根本不填充。 當步長為 n=10 時,單個補丁從元素 34(在圖像的中間)開始。

tf.extract_image_patches 由本答案中所述的特征庫實現。 您可以研究該代碼以准確了解補丁位置和填充是如何計算的。

簡介

在這里,我想展示一個相當簡單的演示,以將tf.image.extract_patches圖像本身一起使用。 我發現該方法的實現量相當小,使用具有適當可視化的實際圖像,所以就在這里。

我們將使用的圖像大小為 (256, 256, 3)。 我們將提取的補丁的形狀為 (128, 128, 3)。 這意味着我們將從圖像中檢索 4 個圖塊。

使用的數據

我將使用花數據集 由於這個答案需要一點數據管道,我將在這里鏈接我的kaggle 內核,它討論使用tf.data.Dataset API 使用數據集。

完成后,我們將瀏覽以下代碼片段。

images, _ = next(iter(train_ds.take(1)))

image = images[0]
plt.imshow(image.numpy().astype("uint8"))

花

在這里,我們從一批圖像中取出一張圖像並按原樣對其進行可視化。

image = tf.expand_dims(image,0) # To create the batch information
patches = tf.image.extract_patches(images=image,
                                   sizes=[1, 128, 128, 1],
                                   strides=[1, 128, 128, 1],
                                   rates=[1, 1, 1, 1],
                                   padding='VALID')

使用這個片段,我們從大小為 (256,256) 的圖像中提取大小為 (128,128) 的塊。 這直接轉化為我希望將圖像分成 4 個圖塊的事實。

可視化

plt.figure(figsize=(10, 10))
for imgs in patches:
    count = 0
    for r in range(2):
        for c in range(2):
            ax = plt.subplot(2, 2, count+1)
            plt.imshow(tf.reshape(imgs[r,c],shape=(128,128,3)).numpy().astype("uint8"))
            count += 1

花的裂痕

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM