
col2im implementation in ConvNet

I'm trying to implement a CNN using only numpy.

While doing backpropagation, I found that I needed col2im to reshape dx, so I looked at the implementation in this file: https://github.com/huyouare/CS231n/blob/master/assignment2/cs231n/im2col.py

import numpy as np


def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
  # First figure out what the size of the output should be
  N, C, H, W = x_shape
  assert (H + 2 * padding - field_height) % stride == 0
  assert (W + 2 * padding - field_width) % stride == 0
  # Integer division so np.arange/np.repeat below receive ints (Python 3)
  out_height = (H + 2 * padding - field_height) // stride + 1
  out_width = (W + 2 * padding - field_width) // stride + 1

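  # Row offsets inside one window, repeated for every channel
  i0 = np.repeat(np.arange(field_height), field_width)
  i0 = np.tile(i0, C)
  # Top row of the window for every output position
  i1 = stride * np.repeat(np.arange(out_height), out_width)
  # Column offsets inside one window, and left column of every window
  j0 = np.tile(np.arange(field_width), field_height * C)
  j1 = stride * np.tile(np.arange(out_width), out_height)
  # Absolute (row, col) in the padded image:
  # one row per kernel element, one column per output position
  i = i0.reshape(-1, 1) + i1.reshape(1, -1)
  j = j0.reshape(-1, 1) + j1.reshape(1, -1)

  # Channel index of each kernel element
  k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)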

  return (k, i, j)


def im2col_indices(x, field_height, field_width, padding=1, stride=1):
  """ An implementation of im2col based on some fancy indexing """
  # Zero-pad the input
  p = padding
  x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')

  k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding,
                               stride)

  cols = x_padded[:, k, i, j]
  C = x.shape[1]
  cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
  return cols


def col2im_indices(cols, x_shape, field_height=3, field_width=3, padding=1,
                   stride=1):
  """ An implementation of col2im based on fancy indexing and np.add.at """
  N, C, H, W = x_shape
  H_padded, W_padded = H + 2 * padding, W + 2 * padding
  x_padded = np.zeros((N, C, H_padded, W_padded), dtype=cols.dtype)
  k, i, j = get_im2col_indices(x_shape, field_height, field_width, padding,
                               stride)
  cols_reshaped = cols.reshape(C * field_height * field_width, -1, N)
  cols_reshaped = cols_reshaped.transpose(2, 0, 1)
  # np.add.at accumulates, so values from overlapping windows are summed
  np.add.at(x_padded, (slice(None), k, i, j), cols_reshaped)
  if padding == 0:
    return x_padded
  return x_padded[:, :, padding:-padding, padding:-padding]

pass

I expected that passing X through im2col_indices and then feeding that output to col2im_indices would return the same X, but it didn't.

I don't understand what col2im actually does.

If I'm right, the output is not the same X because im2col_indices copies each cell of X into several cols, and col2im_indices adds all of those copies back together, so each cell comes back multiplied by the number of windows that contain it.

Say you have a simple image X like this:

 1 2 3
 4 5 6
 7 8 9

and you convert it with kernel size 3, stride 1 and padding 1 ("same" padding). The result, with one column per output position, would be:

0 0 0 0 1 2 0 4 5
0 0 0 1 2 3 4 5 6
0 0 0 2 3 0 5 6 0
0 1 2 0 4 5 0 7 8
1 2 3 4 5 6 7 8 9
2 3 0 5 6 0 8 9 0
0 4 5 0 7 8 0 0 0
4 5 6 7 8 9 0 0 0
5 6 0 8 9 0 0 0 0
* *   * *

As you can see, the first cell, with value 1, shows up in four cols: 0, 1, 3 and 4 (marked with *).
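The matrix above can be reproduced with a short loop. This is a minimal sketch, not the cs231n code: `im2col` here is a hypothetical single-image helper taking a `(C, H, W)` array, using the same layout as `im2col_indices` (one row per kernel offset, one column per output position).

```python
import numpy as np


def im2col(x, k, pad=1, stride=1):
    """Hypothetical single-image im2col: (C, H, W) -> (C*k*k, out_h*out_w).

    Row index = kernel offset (channel, then row, then column);
    column index = output position in row-major order.
    """
    C, H, W = x.shape
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)), mode="constant")
    out_h = (H + 2 * pad - k) // stride + 1
    out_w = (W + 2 * pad - k) // stride + 1
    cols = np.empty((C * k * k, out_h * out_w), dtype=x.dtype)
    for oy in range(out_h):
        for ox in range(out_w):
            window = xp[:, oy * stride:oy * stride + k, ox * stride:ox * stride + k]
            cols[:, oy * out_w + ox] = window.reshape(-1)
    return cols


X = np.arange(1, 10).reshape(1, 3, 3)  # one channel holding 1..9
cols = im2col(X, k=3, pad=1)
print(cols)  # the same 9x9 matrix as above

# Which columns (output positions) contain the value 1?
print(np.flatnonzero((cols == 1).any(axis=0)))  # [0 1 3 4]
```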

col2im_indices first zero-initializes an image of the padded size and then adds each col to it. Focusing on the first cell, the process looks like this:

1. Zero-initialized image

0 0 0 0 0
0 0 0 0 0
0 0 0 0 0
0 0 0 0 0
0 0 0 0 0

2. Add col 0

0 0 0 0 0     0 0 0 - -     0 0 0 0 0
0 0 0 0 0     0 1 2 - -     0 1 2 0 0
0 0 0 0 0  +  0 4 5 - -  =  0 4 5 0 0
0 0 0 0 0     - - - - -     0 0 0 0 0
0 0 0 0 0     - - - - -     0 0 0 0 0

3. Add col 1

0 0 0 0 0     - 0 0 0 -     0  0  0  0  0
0 1 2 0 0     - 1 2 3 -     0  2  4  3  0
0 4 5 0 0  +  - 4 5 6 -  =  0  8 10  6  0
0 0 0 0 0     - - - - -     0  0  0  0  0
0 0 0 0 0     - - - - -     0  0  0  0  0

4. Add col 3

0  0  0  0  0     - - - - -     0  0  0  0  0
0  2  4  3  0     0 1 2 - -     0  3  6  3  0
0  8 10  6  0  +  0 4 5 - -  =  0 12 15  6  0
0  0  0  0  0     0 7 8 - -     0  7  8  0  0 
0  0  0  0  0     - - - - -     0  0  0  0  0

5. Add col 4

0  0  0  0  0     - - - - -     0  0  0  0  0
0  3  6  3  0     - 1 2 3 -     0  4  8  6  0
0 12 15  6  0  +  - 4 5 6 -  =  0 16 20 12  0
0  7  8  0  0     - 7 8 9 -     0 14 16  9  0
0  0  0  0  0     - - - - -     0  0  0  0  0 
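These steps amount to a scatter-add. Below is a minimal loop sketch, assuming the same hypothetical single-image `im2col` helper. `col2im_padded` and its `which` parameter are illustrative names, not from cs231n. They let you stop after cols 0, 1, 3 and 4 and compare the result with step 5.

```python
import numpy as np


def im2col(x, k, pad=1, stride=1):
    """Hypothetical single-image im2col: (C, H, W) -> (C*k*k, out_h*out_w)."""
    C, H, W = x.shape
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)), mode="constant")
    out_h = (H + 2 * pad - k) // stride + 1
    out_w = (W + 2 * pad - k) // stride + 1
    cols = np.empty((C * k * k, out_h * out_w), dtype=x.dtype)
    for oy in range(out_h):
        for ox in range(out_w):
            window = xp[:, oy * stride:oy * stride + k, ox * stride:ox * stride + k]
            cols[:, oy * out_w + ox] = window.reshape(-1)
    return cols


def col2im_padded(cols, x_shape, k, pad=1, stride=1, which=None):
    """Hypothetical loop col2im returning the padded canvas.

    Each selected column is reshaped into a (C, k, k) window and *added*
    at the position it was taken from; `which` restricts the sum to a
    subset of columns so intermediate steps can be inspected.
    """
    C, H, W = x_shape
    out_h = (H + 2 * pad - k) // stride + 1
    out_w = (W + 2 * pad - k) // stride + 1
    canvas = np.zeros((C, H + 2 * pad, W + 2 * pad), dtype=cols.dtype)
    if which is None:
        which = range(out_h * out_w)
    for c in which:
        oy, ox = divmod(c, out_w)  # column index -> output position
        window = cols[:, c].reshape(C, k, k)
        canvas[:, oy * stride:oy * stride + k, ox * stride:ox * stride + k] += window
    return canvas


X = np.arange(1, 10).reshape(1, 3, 3)
cols = im2col(X, k=3, pad=1)

step5 = col2im_padded(cols, X.shape, k=3, which=[0, 1, 3, 4])
print(step5[0])  # the canvas after step 5

full = col2im_padded(cols, X.shape, k=3)[:, 1:-1, 1:-1]  # all 9 cols, padding cropped
print(full[0])
```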

The first cell is multiplied by 4 when converted back. For this simple image, col2im_indices(im2col_indices(X)) should give you

 4  12  12
24  45  36
28  48  36
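To check this output with the question's own code, here is a condensed copy of the three cs231n functions with the Python 3 fixes applied (`//` for the output size and `field_width` in the second assertion). Treat it as a sketch of that file, not a verbatim copy.

```python
import numpy as np


def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
    N, C, H, W = x_shape
    assert (H + 2 * padding - field_height) % stride == 0
    assert (W + 2 * padding - field_width) % stride == 0
    out_height = (H + 2 * padding - field_height) // stride + 1
    out_width = (W + 2 * padding - field_width) // stride + 1
    i0 = np.tile(np.repeat(np.arange(field_height), field_width), C)
    i1 = stride * np.repeat(np.arange(out_height), out_width)
    j0 = np.tile(np.arange(field_width), field_height * C)
    j1 = stride * np.tile(np.arange(out_width), out_height)
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)
    k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)
    return k, i, j


def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    p = padding
    x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode="constant")
    k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding, stride)
    cols = x_padded[:, k, i, j]  # gather: one copy per window a cell falls in
    C = x.shape[1]
    return cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)


def col2im_indices(cols, x_shape, field_height=3, field_width=3, padding=1, stride=1):
    N, C, H, W = x_shape
    x_padded = np.zeros((N, C, H + 2 * padding, W + 2 * padding), dtype=cols.dtype)
    k, i, j = get_im2col_indices(x_shape, field_height, field_width, padding, stride)
    cols_reshaped = cols.reshape(C * field_height * field_width, -1, N).transpose(2, 0, 1)
    np.add.at(x_padded, (slice(None), k, i, j), cols_reshaped)  # scatter-add copies back
    if padding == 0:
        return x_padded
    return x_padded[:, :, padding:-padding, padding:-padding]


X = np.arange(1, 10).reshape(1, 1, 3, 3)
Y = col2im_indices(im2col_indices(X, 3, 3), X.shape, 3, 3)
print(Y[0, 0])
```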

Compared to the original image, the four corner cells (1, 3, 7, 9) are multiplied by 4, the four edge cells (2, 4, 6, 8) by 6, and the center cell (5) by 9.

For large images (kernel size 3, stride 1, padding 1), most cells are interior and get multiplied by 9, and I think that roughly means your learning rate is effectively 9 times larger than you think.
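The multipliers are just the number of sliding windows that cover each cell, which you can get by running an image of ones through the round trip. This sketch uses hypothetical loop helpers `im2col`, `col2im` and `overlap_counts` (assumed names; single image, kernel 3, padding 1, stride 1).

```python
import numpy as np


def im2col(x, k, pad=1, stride=1):
    """Hypothetical single-image im2col: (C, H, W) -> (C*k*k, out_h*out_w)."""
    C, H, W = x.shape
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)), mode="constant")
    out_h = (H + 2 * pad - k) // stride + 1
    out_w = (W + 2 * pad - k) // stride + 1
    cols = np.empty((C * k * k, out_h * out_w), dtype=x.dtype)
    for oy in range(out_h):
        for ox in range(out_w):
            window = xp[:, oy * stride:oy * stride + k, ox * stride:ox * stride + k]
            cols[:, oy * out_w + ox] = window.reshape(-1)
    return cols


def col2im(cols, x_shape, k, pad=1, stride=1):
    """Hypothetical loop col2im: scatter-add each column back, then crop the padding."""
    C, H, W = x_shape
    out_h = (H + 2 * pad - k) // stride + 1
    out_w = (W + 2 * pad - k) // stride + 1
    xp = np.zeros((C, H + 2 * pad, W + 2 * pad), dtype=cols.dtype)
    for oy in range(out_h):
        for ox in range(out_w):
            window = cols[:, oy * out_w + ox].reshape(C, k, k)
            xp[:, oy * stride:oy * stride + k, ox * stride:ox * stride + k] += window
    return xp[:, pad:pad + H, pad:pad + W]


def overlap_counts(shape, k=3, pad=1, stride=1):
    """Number of sliding windows covering each input cell."""
    ones = np.ones(shape, dtype=np.int64)
    return col2im(im2col(ones, k, pad, stride), shape, k, pad, stride)


counts = overlap_counts((1, 3, 3))
print(counts[0])  # 4 on the corners, 6 on the edges, 9 in the center

X = np.arange(1, 10).reshape(1, 3, 3)
print(np.array_equal(col2im(im2col(X, 3), X.shape, 3), X * counts))  # True

big = overlap_counts((1, 32, 32))
print((big == 9).mean())  # 0.87890625: the 30x30 interior of a 32x32 image
```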

I'm replying to this two-year-old thread in case it helps someone in the future.

Here is my understanding. In the CNN backpropagation context, the col2im matrix is the product of the filters and the back-propagated error (dout). Note that this matrix is already a product of two matrices. In the im2col case in the forward pass, by contrast, we have only stretched the input into the im2col matrix, ready for the multiplication (convolution). Because of this difference, in col2im we need to add the back-propagated error into every input index that contributed to it.
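A dot-product test makes this concrete. col2im behaves as the transpose (adjoint) of im2col rather than its inverse: <im2col(x), y> = <x, col2im(y)> for any x and y. This is the property the backward pass of a linear operation needs, since the gradient of a linear map is its transpose applied to the upstream error. The sketch below uses hypothetical single-image loop helpers, not the cs231n functions.

```python
import numpy as np


def im2col(x, k, pad=1, stride=1):
    """Hypothetical single-image im2col: (C, H, W) -> (C*k*k, out_h*out_w)."""
    C, H, W = x.shape
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)), mode="constant")
    out_h = (H + 2 * pad - k) // stride + 1
    out_w = (W + 2 * pad - k) // stride + 1
    cols = np.empty((C * k * k, out_h * out_w), dtype=x.dtype)
    for oy in range(out_h):
        for ox in range(out_w):
            window = xp[:, oy * stride:oy * stride + k, ox * stride:ox * stride + k]
            cols[:, oy * out_w + ox] = window.reshape(-1)
    return cols


def col2im(cols, x_shape, k, pad=1, stride=1):
    """Hypothetical loop col2im: scatter-add each column back, then crop the padding."""
    C, H, W = x_shape
    out_h = (H + 2 * pad - k) // stride + 1
    out_w = (W + 2 * pad - k) // stride + 1
    xp = np.zeros((C, H + 2 * pad, W + 2 * pad), dtype=cols.dtype)
    for oy in range(out_h):
        for ox in range(out_w):
            window = cols[:, oy * out_w + ox].reshape(C, k, k)
            xp[:, oy * stride:oy * stride + k, ox * stride:ox * stride + k] += window
    return xp[:, pad:pad + H, pad:pad + W]


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 5))               # (C, H, W)
y = rng.standard_normal(im2col(x, 3).shape)      # arbitrary upstream matrix, (18, 25)

lhs = np.sum(im2col(x, 3) * y)                   # <im2col(x), y>
rhs = np.sum(x * col2im(y, x.shape, 3))          # <x, col2im(y)>
print(np.isclose(lhs, rhs))                      # True
```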

Let us consider an example with a 1x5x5 input, a single 1x3x3 filter, padding 0 and stride 1. The indices of the input look like this:

[0,0] [0,1] [0,2] [0,3] [0,4]
[1,0] [1,1] [1,2] [1,3] [1,4]
[2,0] [2,1] [2,2] [2,3] [2,4]
[3,0] [3,1] [3,2] [3,3] [3,4]
[4,0] [4,1] [4,2] [4,3] [4,4]

The resulting 9x9 im2col index matrix used for the forward-propagation matrix multiplication (one row per output position, one column per filter element) looks like this:

im2col indices

<-----------------------  9 ----------------------------->
[ 0] [0,0] [0,1] [0,2] [1,0] [1,1] [1,2] [2,0] [2,1] [2,2] 
[ 1] [0,1] [0,2] [0,3] [1,1] [1,2] [1,3] [2,1] [2,2] [2,3] 
[ 2] [0,2] [0,3] [0,4] [1,2] [1,3] [1,4] [2,2] [2,3] [2,4] 
[ 3] [1,0] [1,1] [1,2] [2,0] [2,1] [2,2] [3,0] [3,1] [3,2] 
[ 4] [1,1] [1,2] [1,3] [2,1] [2,2] [2,3] [3,1] [3,2] [3,3] 
[ 5] [1,2] [1,3] [1,4] [2,2] [2,3] [2,4] [3,2] [3,3] [3,4] 
[ 6] [2,0] [2,1] [2,2] [3,0] [3,1] [3,2] [4,0] [4,1] [4,2] 
[ 7] [2,1] [2,2] [2,3] [3,1] [3,2] [3,3] [4,1] [4,2] [4,3] 
[ 8] [2,2] [2,3] [2,4] [3,2] [3,3] [3,4] [4,2] [4,3] [4,4] 
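The table above can be generated directly. This is a small sketch for a 5x5 input, a 3x3 filter, padding 0 and stride 1. The function name `im2col_index_table` is hypothetical. Rows are output positions and columns are filter elements.

```python
def im2col_index_table(H, W, k, stride=1):
    """Input (row, col) read by each (output position, filter element) pair; no padding."""
    out_h = (H - k) // stride + 1
    out_w = (W - k) // stride + 1
    table = []
    for oy in range(out_h):
        for ox in range(out_w):
            table.append([(oy * stride + dy, ox * stride + dx)
                          for dy in range(k) for dx in range(k)])
    return table


table = im2col_index_table(5, 5, 3)
for r, row in enumerate(table):
    print(f"[{r:2d}] " + " ".join(f"[{a},{b}]" for a, b in row))
```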

In the backward pass, the col2im matrix is generated by multiplying the back-propagated error dout with the filter, so each entry, laid out with the indices shown above, is already the result of that multiplication. When we convert it back to the input error, we need to add up all entries whose index points to the same location in the input error array.

For example:

input_error[0,0] = im2col_error[0,0]
input_error[0,1] = im2col_error[0,1] + im2col_error[1,0]
input_error[0,2] = im2col_error[0,2] + im2col_error[1,1] + im2col_error[2,0]
....
....

This is evident from the indices matrix above.
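The bookkeeping can be checked numerically. This is a minimal sketch, assuming a single 5x5 channel, one 3x3 filter, padding 0, stride 1 (as in the example) and cross-correlation as the forward operation. `rows`/`cols` encode the index table above, `np.add.at` does the summing, and a central-difference gradient confirms the result.

```python
import numpy as np

rng = np.random.default_rng(0)
H = W = 5          # single-channel 5x5 input, as in the example
k = 3              # one 3x3 filter, padding 0, stride 1
out = H - k + 1    # 3x3 output -> 9 output positions

x = rng.standard_normal((H, W))
w = rng.standard_normal((k, k))

# Index table: row p = output position, column q = filter element
rows = np.array([[oy + dy for dy in range(k) for dx in range(k)]
                 for oy in range(out) for ox in range(out)])
cols = np.array([[ox + dx for dy in range(k) for dx in range(k)]
                 for oy in range(out) for ox in range(out)])


def forward(x):
    """Cross-correlation as a (9 x 9) im2col matrix times the flattened filter."""
    return (x[rows, cols] @ w.ravel()).reshape(out, out)


dout = rng.standard_normal((out, out))  # upstream gradient
# im2col_error[p, q] = dout at output p times filter element q
im2col_error = np.outer(dout.ravel(), w.ravel())

# col2im: add every entry into the input cell its index points to
input_error = np.zeros_like(x)
np.add.at(input_error, (rows, cols), im2col_error)

e = im2col_error
print(np.isclose(input_error[0, 1], e[0, 1] + e[1, 0]))            # True
print(np.isclose(input_error[0, 2], e[0, 2] + e[1, 1] + e[2, 0]))  # True

# Central-difference gradient of sum(forward(x) * dout) with respect to x
eps = 1e-4
num = np.zeros_like(x)
for a in range(H):
    for b in range(W):
        step = np.zeros_like(x)
        step[a, b] = eps
        num[a, b] = (np.sum(forward(x + step) * dout)
                     - np.sum(forward(x - step) * dout)) / (2 * eps)
print(np.allclose(input_error, num))  # True
```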
