tensorflow tf.extract_image_patches

The official TensorFlow documentation for the extract_image_patches function says:

tf.extract_image_patches(
    images,
    ksizes,
    strides,
    rates,
    padding,
    name=None
)

I understand all the required arguments except rates, probably because of the explanation given in the API docs:

rates: A list of ints that has length >= 4. 1-D of length 4. 
Must be: [1, rate_rows, rate_cols, 1]. This is the input stride, 
specifying how far two consecutive patch samples are in the input. 
Equivalent to extracting patches with 
patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1), 
followed by subsampling them spatially by a factor of rates. This is 
equivalent to rate in dilated (a.k.a. Atrous) convolutions.
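As a worked instance of the quoted formula, applied per spatial dimension: a patch of size 3 with rate 2 spans 5 input pixels, and subsampling that span by the rate keeps 3 of them (positions 0, 2 and 4). So a 3 × 3 patch with rates [1, 2, 2, 1] draws its 9 pixels from a 5 × 5 region.

```latex
k_{\text{eff}} = k + (k - 1)(r - 1), \qquad
k = 3,\ r = 2 \;\Rightarrow\; k_{\text{eff}} = 3 + (3 - 1)(2 - 1) = 5
```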

This only confuses me more: what is the difference between strides and rates? Could someone explain, with a simple example and in simple language, what the rates parameter does? I have seen a few examples of extracting image patches from an image, and all of them used [1, 1, 1, 1]. Should it always be 1?

Here is how the method works:

  • ksizes sets the dimensions of each patch, i.e. how many rows and columns of pixels each patch contains.
  • strides denotes the length of the gap between the start of one patch and the start of the next consecutive patch within the original image.
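  • rates is the spacing, in the original image, between consecutive pixels that end up in the same patch: a rate of r means the patch takes every r-th pixel in that dimension, so a rate of 1 gives an ordinary contiguous patch. Like strides, it is given as [1, rate_rows, rate_cols, 1]. (The example below helps illustrate this.)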
  • padding is either "VALID", which means every patch must be fully contained in the image, or "SAME", which means patches are allowed to be incomplete (the remaining pixels will be filled in with zeroes).
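The difference between strides and rates can be sketched in a few lines of plain Python for a single spatial dimension. The helper names are hypothetical, the code assumes VALID padding, and it is not TensorFlow's implementation:

```python
# Hypothetical helpers (not TensorFlow code) for ONE spatial dimension,
# assuming VALID padding.

def effective_size(ksize, rate):
    # Span of input pixels a single patch covers once the gaps are included.
    return ksize + (ksize - 1) * (rate - 1)


def patch_starts(in_size, ksize, stride, rate):
    # strides: consecutive patches START `stride` pixels apart, and under
    # VALID padding each patch must fit entirely inside the input.
    return list(range(0, in_size - effective_size(ksize, rate) + 1, stride))


def patch_offsets(ksize, rate):
    # rates: pixels sampled WITHIN one patch are `rate` pixels apart.
    return [i * rate for i in range(ksize)]


# A 10-pixel-wide input with size-3 patches and stride 5.
for rate in (1, 2):
    print("rate", rate,
          "starts", patch_starts(10, 3, 5, rate),
          "offsets", patch_offsets(3, rate))
# rate 1 starts [0, 5] offsets [0, 1, 2]
# rate 2 starts [0, 5] offsets [0, 2, 4]
```

Both rates give the same patch starts, because the stride is 5 either way; only the pixels sampled inside each patch change. A rate of 1 is the ordinary case of contiguous patches, which is why most examples pass [1, 1, 1, 1].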

Here is some sample code with output to help demonstrate how it works:

import tensorflow as tf

n = 10
# images is a 1 x 10 x 10 x 1 array that contains the numbers 1 through 100 in order
images = [[[[x * n + y + 1] for y in range(n)] for x in range(n)]]

# We generate four outputs as follows:
# 1. 3x3 patches with stride length 5
# 2. Same as above, but the rate is increased to 2
# 3. 4x4 patches with stride length 7; only one patch should be generated
# 4. Same as above, but with padding set to 'SAME'
# Requires TensorFlow 1.x (tf.Session and tf.extract_image_patches).
with tf.Session() as sess:
  print(tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1], rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n')
  print(tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1], rates=[1, 2, 2, 1], padding='VALID').eval(), '\n\n')
  print(tf.extract_image_patches(images=images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1], rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n')
  print(tf.extract_image_patches(images=images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1], rates=[1, 1, 1, 1], padding='SAME').eval())

Output:

[[[[ 1  2  3 11 12 13 21 22 23]
   [ 6  7  8 16 17 18 26 27 28]]

  [[51 52 53 61 62 63 71 72 73]
   [56 57 58 66 67 68 76 77 78]]]]


[[[[  1   3   5  21  23  25  41  43  45]
   [  6   8  10  26  28  30  46  48  50]]

  [[ 51  53  55  71  73  75  91  93  95]
   [ 56  58  60  76  78  80  96  98 100]]]]


[[[[ 1  2  3  4 11 12 13 14 21 22 23 24 31 32 33 34]]]]


[[[[  1   2   3   4  11  12  13  14  21  22  23  24  31  32  33  34]
   [  8   9  10   0  18  19  20   0  28  29  30   0  38  39  40   0]]

  [[ 71  72  73  74  81  82  83  84  91  92  93  94   0   0   0   0]
   [ 78  79  80   0  88  89  90   0  98  99 100   0   0   0   0   0]]]]

For example, in the first result the pixels that end up in some patch are the ones marked with * below:

 *  *  *  4  5  *  *  *  9 10 
 *  *  * 14 15  *  *  * 19 20 
 *  *  * 24 25  *  *  * 29 30 
31 32 33 34 35 36 37 38 39 40 
41 42 43 44 45 46 47 48 49 50 
 *  *  * 54 55  *  *  * 59 60 
 *  *  * 64 65  *  *  * 69 70 
 *  *  * 74 75  *  *  * 79 80 
81 82 83 84 85 86 87 88 89 90 
91 92 93 94 95 96 97 98 99 100 
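The marked picture can be regenerated, and redrawn for other settings, with a small pure-Python helper. coverage_grid is a hypothetical name; it assumes VALID padding, a square n × n image numbered 1 to n*n row by row, and the same ksize, stride and rate in both dimensions:

```python
def coverage_grid(n, ksize, stride, rate):
    # Hypothetical helper: marks with '*' every pixel of an n x n image
    # (values 1..n*n, row-major) that some VALID-padded patch samples.
    eff = ksize + (ksize - 1) * (rate - 1)       # effective patch size
    starts = range(0, n - eff + 1, stride)       # patch start offsets
    covered = {(top + a * rate, left + b * rate)
               for top in starts for left in starts
               for a in range(ksize) for b in range(ksize)}
    rows = []
    for x in range(n):
        cells = ["*" if (x, y) in covered else str(x * n + y + 1)
                 for y in range(n)]
        rows.append(" ".join(cell.rjust(3) for cell in cells))
    return rows


for line in coverage_grid(10, 3, 5, 1):
    print(line)
```

Calling coverage_grid(10, 3, 5, 2) instead shows the pattern behind the second result: each patch still holds 9 pixels, but they are every other pixel of a 5 × 5 region.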

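As you can see, there are 2 rows and 2 columns of patches; these are the out_rows and out_cols dimensions of the result, whose shape is [batch, out_rows, out_cols, ksize_rows * ksize_cols * depth], here [1, 2, 2, 9].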
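The patch-grid sizes for all four calls can be checked with the usual VALID/SAME output-size rules from TensorFlow's convolution documentation, applied per spatial dimension with in = 10 and the effective patch size from the rates formula. These are stated here as the standard rules and they agree with all four outputs above:

```latex
\begin{aligned}
\text{VALID:}\quad & \text{out} = \left\lceil \frac{\text{in} - k_{\text{eff}} + 1}{s} \right\rceil,
\qquad \text{SAME:}\quad \text{out} = \left\lceil \frac{\text{in}}{s} \right\rceil \\
\text{call 1: } & \lceil (10 - 3 + 1)/5 \rceil = 2 \\
\text{call 2: } & \lceil (10 - 5 + 1)/5 \rceil = 2 \quad (k_{\text{eff}} = 5 \text{ for } k = 3,\ r = 2) \\
\text{call 3: } & \lceil (10 - 4 + 1)/7 \rceil = 1 \\
\text{call 4: } & \lceil 10/7 \rceil = 2
\end{aligned}
```

With depth 1, that gives output shapes [1, 2, 2, 9], [1, 2, 2, 9], [1, 1, 1, 16] and [1, 2, 2, 16].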

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint this post, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 