
Implement MATLAB's im2col 'sliding' in Python

Q: How to speed this up?

Below is my implementation of MATLAB's im2col 'sliding' with the additional feature of returning every nth column. The function takes an image (or any 2-D array) and slides from left to right, top to bottom, picking off every overlapping sub-image of a given size and returning an array whose columns are the sub-images.

import numpy as np

def im2col_sliding(image, block_size, skip=1):

    rows, cols = image.shape
    horz_blocks = cols - block_size[1] + 1
    vert_blocks = rows - block_size[0] + 1

    output_vectors = np.zeros((block_size[0] * block_size[1], horz_blocks * vert_blocks))
    itr = 0
    for v_b in range(vert_blocks):
        for h_b in range(horz_blocks):
            output_vectors[:, itr] = image[v_b: v_b + block_size[0], h_b: h_b + block_size[1]].ravel()
            itr += 1

    return output_vectors[:, ::skip]

Example:

a = np.arange(16).reshape(4, 4)
print(a)
print(im2col_sliding(a, (2, 2)))  # return every overlapping 2x2 patch
print(im2col_sliding(a, (2, 2), 4))  # return every 4th vector

Output:

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
[[  0.   1.   2.   4.   5.   6.   8.   9.  10.]
 [  1.   2.   3.   5.   6.   7.   9.  10.  11.]
 [  4.   5.   6.   8.   9.  10.  12.  13.  14.]
 [  5.   6.   7.   9.  10.  11.  13.  14.  15.]]
[[  0.   5.  10.]
 [  1.   6.  11.]
 [  4.   9.  14.]
 [  5.  10.  15.]]

The performance isn't great. In particular, whether I call im2col_sliding(big_matrix, (8, 8)) (62001 columns) or im2col_sliding(big_matrix, (8, 8), 10) (6201 columns, keeping only every 10th vector), it takes the same amount of time, where big_matrix is 256 x 256.

I'm looking for any ideas to speed this up.
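The column counts quoted above follow from the window arithmetic: an H x W array with an h x w block has (H - h + 1) * (W - w + 1) window positions, and slicing with [:, ::skip] keeps ceil(n / skip) of them. A minimal sketch; the helper name `im2col_num_columns` is hypothetical, not part of the question's code:

```python
def im2col_num_columns(image_shape, block_size, skip=1):
    """Number of columns returned when keeping every skip-th sliding window."""
    rows, cols = image_shape
    vert_blocks = rows - block_size[0] + 1  # window positions down the image
    horz_blocks = cols - block_size[1] + 1  # window positions across the image
    total = max(vert_blocks, 0) * max(horz_blocks, 0)
    # [:, ::skip] keeps columns 0, skip, 2*skip, ..., i.e. ceil(total / skip) of them
    return (total + skip - 1) // skip

print(im2col_num_columns((256, 256), (8, 8)))      # 62001
print(im2col_num_columns((256, 256), (8, 8), 10))  # 6201
```

Since the loop version fills all 62001 columns before the final [:, ::skip] slice, its cost does not depend on skip at all.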

Approach #1

We could use broadcasting here to get the indices of all those sliding windows in one go, and then index with them to achieve a vectorized solution. This is inspired by Efficient Implementation of im2col and col2im.

Here's the implementation -

def im2col_sliding_broadcasting(A, BSZ, stepsize=1):
    # Parameters
    M,N = A.shape
    col_extent = N - BSZ[1] + 1
    row_extent = M - BSZ[0] + 1
    
    # Get Starting block indices
    start_idx = np.arange(BSZ[0])[:,None]*N + np.arange(BSZ[1])
    
    # Get offsetted indices across the height and width of input array
    offset_idx = np.arange(row_extent)[:,None]*N + np.arange(col_extent)
    
    # Get all actual indices & index into input array for final output
    return np.take (A,start_idx.ravel()[:,None] + offset_idx.ravel()[::stepsize])
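To see what the broadcasting builds, here are the intermediate index arrays for a 4x5 input (values equal to flat indices) and a 2x3 block. `start_idx` holds the flat indices of the top-left window, `offset_idx` holds the flat index of every window's top-left corner, and their broadcast sum is the full index matrix passed to `np.take`. This only illustrates the intermediates of the function above; it is not a separate method:

```python
import numpy as np

A = np.arange(20).reshape(4, 5)   # 4x5 input, flat index == value
BSZ = (2, 3)
M, N = A.shape

# Flat indices of the first (top-left) 2x3 window, one row per window row
start_idx = np.arange(BSZ[0])[:, None] * N + np.arange(BSZ[1])
print(start_idx)
# [[0 1 2]
#  [5 6 7]]

# Flat index of each window's top-left corner: 3 x 3 window positions
offset_idx = np.arange(M - BSZ[0] + 1)[:, None] * N + np.arange(N - BSZ[1] + 1)
print(offset_idx)
# [[ 0  1  2]
#  [ 5  6  7]
#  [10 11 12]]

# (6, 1) + (9,) broadcasts to (6, 9): column j holds window j's flat indices
idx = start_idx.ravel()[:, None] + offset_idx.ravel()
print(np.take(A, idx)[:, 0])   # first column: [0 1 2 5 6 7]
```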

Approach #2

Using newly gained knowledge of NumPy array strides, which let us create such sliding windows as views into the input, we have another efficient solution:

def im2col_sliding_strided(A, BSZ, stepsize=1):
    # Parameters
    m,n = A.shape
    s0, s1 = A.strides    
    nrows = m-BSZ[0]+1
    ncols = n-BSZ[1]+1
    shp = BSZ[0],BSZ[1],nrows,ncols
    strd = s0,s1,s0,s1
    
    out_view = np.lib.stride_tricks.as_strided(A, shape=shp, strides=strd)
    return out_view.reshape(BSZ[0]*BSZ[1],-1)[:,::stepsize]
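On NumPy 1.20 or newer, `numpy.lib.stride_tricks.sliding_window_view` builds the same strided view without hand-computed strides. The function below is a sketch of that swap (the name `im2col_sliding_swv` is ours, not from the answer); it transposes the view to (block_rows, block_cols, nrows, ncols) so that the reshape gives the same column order as `im2col_sliding_strided`:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def im2col_sliding_swv(A, BSZ, stepsize=1):
    # Read-only view of shape (nrows, ncols, BSZ[0], BSZ[1]); no data copied yet
    windows = sliding_window_view(A, (BSZ[0], BSZ[1]))
    # Reorder to (BSZ[0], BSZ[1], nrows, ncols), matching the as_strided layout,
    # then flatten; this reshape copies, giving one column per window
    cols = windows.transpose(2, 3, 0, 1).reshape(BSZ[0] * BSZ[1], -1)
    return cols[:, ::stepsize]

a = np.arange(20).reshape(4, 5)
print(im2col_sliding_swv(a, (2, 3), stepsize=2))
```

With the 4x5 array from the sample runs, this should reproduce the stepsize=2 output shown there. Unlike as_strided, sliding_window_view checks the window shape against the input, so a too-large block raises an error instead of reading out of bounds.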

Approach #3

The strided method from the previous approach has been incorporated into the scikit-image module, which gives a less messy version, like so:

from skimage.util import view_as_windows as viewW

def im2col_sliding_strided_v2(A, BSZ, stepsize=1):
    return viewW(A, (BSZ[0],BSZ[1])).reshape(-1,BSZ[0]*BSZ[1]).T[:,::stepsize]

Sample runs -

In [106]: a      # Input array
Out[106]: 
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19]])

In [107]: im2col_sliding_broadcasting(a, (2,3))
Out[107]: 
array([[ 0,  1,  2,  5,  6,  7, 10, 11, 12],
       [ 1,  2,  3,  6,  7,  8, 11, 12, 13],
       [ 2,  3,  4,  7,  8,  9, 12, 13, 14],
       [ 5,  6,  7, 10, 11, 12, 15, 16, 17],
       [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
       [ 7,  8,  9, 12, 13, 14, 17, 18, 19]])

In [108]: im2col_sliding_broadcasting(a, (2,3), stepsize=2)
Out[108]: 
array([[ 0,  2,  6, 10, 12],
       [ 1,  3,  7, 11, 13],
       [ 2,  4,  8, 12, 14],
       [ 5,  7, 11, 15, 17],
       [ 6,  8, 12, 16, 18],
       [ 7,  9, 13, 17, 19]])
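A quick way to confirm that the vectorized versions match the loop-based original is to compare them on random inputs across block and step sizes. A sketch, with the functions from this page copied in (lightly condensed, Python 3 `range` in place of `xrange`) so it runs on its own; the scikit-image variant is left out to avoid the extra dependency:

```python
import numpy as np

def im2col_sliding(image, block_size, skip=1):
    # Loop-based original from the question
    rows, cols = image.shape
    horz_blocks = cols - block_size[1] + 1
    vert_blocks = rows - block_size[0] + 1
    out = np.zeros((block_size[0] * block_size[1], horz_blocks * vert_blocks))
    itr = 0
    for v_b in range(vert_blocks):
        for h_b in range(horz_blocks):
            out[:, itr] = image[v_b:v_b + block_size[0], h_b:h_b + block_size[1]].ravel()
            itr += 1
    return out[:, ::skip]

def im2col_sliding_broadcasting(A, BSZ, stepsize=1):
    M, N = A.shape
    start_idx = np.arange(BSZ[0])[:, None] * N + np.arange(BSZ[1])
    offset_idx = np.arange(M - BSZ[0] + 1)[:, None] * N + np.arange(N - BSZ[1] + 1)
    return np.take(A, start_idx.ravel()[:, None] + offset_idx.ravel()[::stepsize])

def im2col_sliding_strided(A, BSZ, stepsize=1):
    m, n = A.shape
    s0, s1 = A.strides
    shp = BSZ[0], BSZ[1], m - BSZ[0] + 1, n - BSZ[1] + 1
    view = np.lib.stride_tricks.as_strided(A, shape=shp, strides=(s0, s1, s0, s1))
    return view.reshape(BSZ[0] * BSZ[1], -1)[:, ::stepsize]

rng = np.random.default_rng(0)
for shape in [(4, 5), (7, 7), (10, 6)]:
    img = rng.integers(0, 255, shape)
    for bsz in [(1, 1), (2, 3), (3, 2)]:
        for step in [1, 2, 4]:
            ref = im2col_sliding(img, bsz, step)
            # array_equal compares values, so float vs int output is fine
            assert np.array_equal(ref, im2col_sliding_broadcasting(img, bsz, step))
            assert np.array_equal(ref, im2col_sliding_strided(img, bsz, step))
print("all variants agree")
```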

Runtime test

In [183]: img = np.random.randint(0,255,(1024,1024))

In [184]: %timeit im2col_sliding(img, (8,8), skip=1)
     ...: %timeit im2col_sliding_broadcasting(img, (8,8), stepsize=1)
     ...: %timeit im2col_sliding_strided(img, (8,8), stepsize=1)
     ...: %timeit im2col_sliding_strided_v2(img, (8,8), stepsize=1)
     ...: 
1 loops, best of 3: 1.29 s per loop
1 loops, best of 3: 226 ms per loop
10 loops, best of 3: 84.5 ms per loop
10 loops, best of 3: 111 ms per loop

In [185]: %timeit im2col_sliding(img, (8,8), skip=4)
     ...: %timeit im2col_sliding_broadcasting(img, (8,8), stepsize=4)
     ...: %timeit im2col_sliding_strided(img, (8,8), stepsize=4)
     ...: %timeit im2col_sliding_strided_v2(img, (8,8), stepsize=4)
     ...: 
1 loops, best of 3: 1.31 s per loop
10 loops, best of 3: 104 ms per loop
10 loops, best of 3: 84.4 ms per loop
10 loops, best of 3: 109 ms per loop

Around 15x speedup there with the strided method over the original loopy version!

To slide a window over several image channels, we can use an updated version of the code provided by Divakar in Implement MATLAB's im2col 'sliding' in Python, i.e.:

import numpy as np
A = np.random.randint(0,9,(2,4,4)) # Sample input array
B = [2,2] # Sample blocksize (rows x columns)
skip=[2,2]
# Parameters 
D,M,N = A.shape
col_extent = N - B[1] + 1
row_extent = M - B[0] + 1

# Get Starting block indices
start_idx = np.arange(B[0])[:,None]*N + np.arange(B[1])

# Generate depth indices
didx=M*N*np.arange(D)
start_idx=(didx[:,None]+start_idx.ravel()).reshape((-1,B[0],B[1]))

# Get offsetted indices across the height and width of input array
offset_idx = np.arange(row_extent)[:,None]*N + np.arange(col_extent)

# Get all actual indices & index into input array for final output
out = np.take(A,start_idx.ravel()[:,None] + offset_idx[::skip[0],::skip[1]].ravel())

Sample run:

A =
[[[6 2 8 5]
  [6 4 7 6]
  [8 6 5 2]
  [3 1 3 7]]

 [[6 0 4 3]
  [7 6 4 6]
  [2 6 7 1]
  [7 6 7 7]]]

out =
[[6 8 8 5]
 [2 5 6 2]
 [6 7 3 3]
 [4 6 1 7]
 [6 4 2 7]
 [0 3 6 1]
 [7 4 7 7]
 [6 6 6 7]]
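Wrapped as a function, the multi-channel indexing can be checked against a plain loop over window positions. A sketch, where `im2col_multichannel` and `im2col_multichannel_loop` are hypothetical names and `skip` is a (row_step, col_step) pair as above; each output column holds one window's D*B[0]*B[1] values, channel by channel:

```python
import numpy as np

def im2col_multichannel(A, B, skip=(1, 1)):
    D, M, N = A.shape
    col_extent = N - B[1] + 1
    row_extent = M - B[0] + 1
    # Flat indices of the top-left window in channel 0
    start_idx = np.arange(B[0])[:, None] * N + np.arange(B[1])
    # Shift by M*N per channel so each window covers every channel
    didx = M * N * np.arange(D)
    start_idx = (didx[:, None] + start_idx.ravel()).reshape((-1, B[0], B[1]))
    # Top-left corner of every window position, subsampled by skip
    offset_idx = np.arange(row_extent)[:, None] * N + np.arange(col_extent)
    return np.take(A, start_idx.ravel()[:, None] + offset_idx[::skip[0], ::skip[1]].ravel())

def im2col_multichannel_loop(A, B, skip=(1, 1)):
    # Reference: visit kept window positions row by row, flatten each D x B[0] x B[1] patch
    D, M, N = A.shape
    cols = [A[:, r:r + B[0], c:c + B[1]].ravel()
            for r in range(0, M - B[0] + 1, skip[0])
            for c in range(0, N - B[1] + 1, skip[1])]
    return np.stack(cols, axis=1)

rng = np.random.default_rng(1)
A = rng.integers(0, 9, (2, 4, 4))
for skip in ([1, 1], [2, 2], [1, 3]):
    assert np.array_equal(im2col_multichannel(A, [2, 2], skip),
                          im2col_multichannel_loop(A, [2, 2], skip))
```

Passing the sample A above with B = [2, 2] and skip = [2, 2] should reproduce the 8 x 4 out matrix.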

To further improve performance (e.g. for convolution), we can also use a batch implementation based on the extended code provided by M Elyia in Implement Matlab's im2col 'sliding' in python, i.e.:

import numpy as np

A = np.arange(3*1*4*4).reshape(3,1,4,4)+1 # 3 sample input arrays, 1 channel each
B = [2,2] # Sample blocksize (rows x columns)
skip = [2,2]

# Parameters 
batch, D,M,N = A.shape
col_extent = N - B[1] + 1
row_extent = M - B[0] + 1

# Get batch block indices
batch_idx = np.arange(batch)[:, None, None] * D * M * N

# Get Starting block indices
start_idx = np.arange(B[0])[None, :,None]*N + np.arange(B[1])

# Generate depth indices
didx=M*N*np.arange(D)
start_idx=(didx[None, :, None]+start_idx.ravel()).reshape((-1,B[0],B[1]))

# Get offsetted indices across the height and width of input array
offset_idx = np.arange(row_extent)[None, :, None]*N + np.arange(col_extent)

# Get all actual indices & index into input array for final output
act_idx = (batch_idx + 
    start_idx.ravel()[None, :, None] + 
    offset_idx[:,::skip[0],::skip[1]].ravel())

out = np.take(A, act_idx)

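Sample run (note: the output below corresponds to skip = [2, 1]; with skip = [2, 2] as set above, each matrix has 4 columns instead of 6):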

A = 
[[[[ 1  2  3  4]
   [ 5  6  7  8]
   [ 9 10 11 12]
   [13 14 15 16]]]


 [[[17 18 19 20]
   [21 22 23 24]
   [25 26 27 28]
   [29 30 31 32]]]


 [[[33 34 35 36]
   [37 38 39 40]
   [41 42 43 44]
   [45 46 47 48]]]] 


out = 
[[[ 1  2  3  9 10 11]
  [ 2  3  4 10 11 12]
  [ 5  6  7 13 14 15]
  [ 6  7  8 14 15 16]]

 [[17 18 19 25 26 27]
  [18 19 20 26 27 28]
  [21 22 23 29 30 31]
  [22 23 24 30 31 32]]

 [[33 34 35 41 42 43]
  [34 35 36 42 43 44]
  [37 38 39 45 46 47]
  [38 39 40 46 47 48]]]
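The batch version can be checked the same way: each batch item's slice of the output should equal the single-image multi-channel result. Below is a sketch with a hypothetical `im2col_batch` wrapper (a lightly simplified form of the snippet above that keeps start indices flat) and a hypothetical per-image loop reference `im2col_single`. The output shape is (batch, D*B[0]*B[1], number of kept windows):

```python
import numpy as np

def im2col_batch(A, B, skip=(1, 1)):
    batch, D, M, N = A.shape
    col_extent = N - B[1] + 1
    row_extent = M - B[0] + 1
    # Offset of each batch item in the flattened array, shape (batch, 1, 1)
    batch_idx = np.arange(batch)[:, None, None] * D * M * N
    # Window indices for channel 0, then shifted by M*N for each channel
    start_idx = np.arange(B[0])[:, None] * N + np.arange(B[1])
    didx = M * N * np.arange(D)
    start_idx = (didx[:, None] + start_idx.ravel()).ravel()
    # Top-left corner of each kept window position
    offset_idx = np.arange(row_extent)[:, None] * N + np.arange(col_extent)
    offset_idx = offset_idx[::skip[0], ::skip[1]].ravel()
    # (batch, 1, 1) + (1, K, 1) + (P,) -> (batch, K, P)
    act_idx = batch_idx + start_idx[None, :, None] + offset_idx
    return np.take(A, act_idx)

def im2col_single(img, B, skip=(1, 1)):
    # Per-image reference: D x M x N -> one column per kept window
    D, M, N = img.shape
    cols = [img[:, r:r + B[0], c:c + B[1]].ravel()
            for r in range(0, M - B[0] + 1, skip[0])
            for c in range(0, N - B[1] + 1, skip[1])]
    return np.stack(cols, axis=1)

A = np.arange(3 * 2 * 5 * 5).reshape(3, 2, 5, 5)
out = im2col_batch(A, [2, 2], [2, 1])
print(out.shape)   # (3, 8, 8)
for b in range(A.shape[0]):
    assert np.array_equal(out[b], im2col_single(A[b], [2, 2], [2, 1]))
```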

I've implemented a fast solution using the Numba JIT compiler. It gives speedups ranging from 5.67x to 3597x, depending on block size and skip size.

Speedup means how many times faster the Numba algorithm is than the original algorithm, e.g. a speedup of 20x means that if the original algorithm took 200 ms, the fast Numba algorithm took 10 ms.

My code needs the following pip modules, installed once with python -m pip install numpy numba timerit matplotlib.

Below are the code, then the speedup plots, then the console output of the time measurements.

import numpy as np

# ----- Original Implementation -----

def im2col_sliding(image, block_size, skip = 1):
    rows, cols = image.shape
    horz_blocks = cols - block_size[1] + 1
    vert_blocks = rows - block_size[0] + 1
    
    if vert_blocks <= 0 or horz_blocks <= 0:
        return np.zeros((block_size[0] * block_size[1], 0), dtype = image.dtype)

    output_vectors = np.zeros((block_size[0] * block_size[1], horz_blocks * vert_blocks), dtype = image.dtype)
    itr = 0
    
    for v_b in range(vert_blocks):
        for h_b in range(horz_blocks):
            output_vectors[:, itr] = image[v_b: v_b + block_size[0], h_b: h_b + block_size[1]].ravel()
            itr += 1

    return output_vectors[:, ::skip]


# ----- Fast Numba Implementation -----
    
import numba

@numba.njit(cache = True)
def im2col_sliding_numba(image, block_size, skip = 1):
    assert skip >= 1
    rows, cols = image.shape
    horz_blocks = cols - block_size[1] + 1
    vert_blocks = rows - block_size[0] + 1
    
    if vert_blocks <= 0 or horz_blocks <= 0:
        return np.zeros((block_size[0] * block_size[1], 0), dtype = image.dtype)
    
    res = np.zeros((block_size[0] * block_size[1], (horz_blocks * vert_blocks + skip - 1) // skip), dtype = image.dtype)
    itr, to_skip, v_b = 0, 0, 0
    
    while True:
        v_b += to_skip // horz_blocks
        if v_b >= vert_blocks:
            break
        h_b_start = to_skip % horz_blocks
        h_cnt = (horz_blocks - h_b_start + skip - 1) // skip
        for i, h_b in zip(range(itr, itr + h_cnt), range(h_b_start, horz_blocks, skip)):
            ii = 0
            for iv in range(v_b, v_b + block_size[0]):
                for ih in range(h_b, h_b + block_size[1]):
                    res[ii, i] = image[iv, ih]
                    ii += 1
        to_skip = skip - (horz_blocks - h_b_start - skip * (h_cnt - 1))
        itr += h_cnt
        v_b += 1
        
    assert itr == res.shape[1]#, (itr, res.shape)

    return res


# ----- Testing -----

from timerit import Timerit
Timerit._default_asciimode = True

side = 256
a = np.random.randint(0, 256, (side, side), dtype = np.uint8)

stats = []

for block_size in [16, 8, 4, 2, 1]:
    for skip_size in [1, 2, 5, 11, 23]:
        print(f'block_size {block_size} skip_size {skip_size}', flush = True)
        for ifn, f in enumerate([im2col_sliding, im2col_sliding_numba]):
            print(f'{f.__name__}: ', end = '', flush = True)
            tim = Timerit(num = 3, verbose = 1)
            for i, t in enumerate(tim):
                if i == 0 and ifn == 1:
                    f(a, (block_size, block_size), skip_size)
                with t:
                    r = f(a, (block_size, block_size), skip_size)
            rt = tim.mean()
            if ifn == 0:
                bt, ba = rt, r
            else:
                assert np.array_equal(ba, r)
                print(f'speedup {round(bt / rt, 2)}x')
                stats.append({
                    'block_size': block_size,
                    'skip_size': skip_size,
                    'speedup': bt / rt,
                })

stats = sorted(stats, key = lambda e: e['speedup'])

import math, matplotlib, matplotlib.pyplot as plt

x = np.arange(len(stats))
y = np.array([e['speedup'] for e in stats])

plt.rcParams['figure.figsize'] = (12.8, 7.2)

for scale in ['linear', 'log']:
    plt.clf()
    plt.xlabel('iteration')
    plt.ylabel(f'speedup_{scale}')
    plt.yscale(scale)
    plt.scatter(x, y, marker = '.')
    for i in range(x.size):
        plt.annotate(
            (f"b{str(stats[i]['block_size']).zfill(2)}s{str(stats[i]['skip_size']).zfill(2)}\n" +
             f"x{round(stats[i]['speedup'], 2 if stats[i]['speedup'] < 100 else 1 if stats[i]['speedup'] < 1000 else None)}"),
            (x[i], y[i]), fontsize = 'small',
        )
    plt.subplots_adjust(left = 0.055, right = 0.99, bottom = 0.08, top = 0.99)
    plt.xlim(left = -0.1)
    if scale == 'linear':
        ymin, ymax = np.amin(y), np.amax(y)
        plt.ylim((ymin - (ymax - ymin) * 0.02, ymax + (ymax - ymin) * 0.05))
        plt.yticks([ymin] + [e for e in plt.yticks()[0] if ymin + 0.01 < e < ymax - 0.01] + [ymax])
        #plt.gca().get_yaxis().set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
    plt.savefig(f'im2col_numba_{scale}.png', dpi = 150)
    plt.show()

In the plots below, the x axis is the iteration (runs sorted by speedup) and the y axis is the speedup; the first plot has a linear y axis and the second a logarithmic one. Each point is labeled bXXsYYxZZ, where XX is the block size, YY is the skip (step) size, and ZZ is the speedup.

Linear plot:

[Figure: speedup per run, linear y axis]

Logarithmic plot:

[Figure: speedup per run, logarithmic y axis]

Console output:

block_size 16 skip_size 1
im2col_sliding: Timed best=549.069 ms, mean=549.069 +- 0.0 ms
im2col_sliding_numba: Timed best=96.841 ms, mean=96.841 +- 0.0 ms
speedup 5.67x
block_size 16 skip_size 2
im2col_sliding: Timed best=559.396 ms, mean=559.396 +- 0.0 ms
im2col_sliding_numba: Timed best=71.132 ms, mean=71.132 +- 0.0 ms
speedup 7.86x
block_size 16 skip_size 5
im2col_sliding: Timed best=561.030 ms, mean=561.030 +- 0.0 ms
im2col_sliding_numba: Timed best=15.000 ms, mean=15.000 +- 0.0 ms
speedup 37.4x
block_size 16 skip_size 11
im2col_sliding: Timed best=559.045 ms, mean=559.045 +- 0.0 ms
im2col_sliding_numba: Timed best=6.719 ms, mean=6.719 +- 0.0 ms
speedup 83.21x
block_size 16 skip_size 23
im2col_sliding: Timed best=562.462 ms, mean=562.462 +- 0.0 ms
im2col_sliding_numba: Timed best=2.514 ms, mean=2.514 +- 0.0 ms
speedup 223.72x
block_size 8 skip_size 1
im2col_sliding: Timed best=373.790 ms, mean=373.790 +- 0.0 ms
im2col_sliding_numba: Timed best=17.441 ms, mean=17.441 +- 0.0 ms
speedup 21.43x
block_size 8 skip_size 2
im2col_sliding: Timed best=375.858 ms, mean=375.858 +- 0.0 ms
im2col_sliding_numba: Timed best=8.791 ms, mean=8.791 +- 0.0 ms
speedup 42.75x
block_size 8 skip_size 5
im2col_sliding: Timed best=376.767 ms, mean=376.767 +- 0.0 ms
im2col_sliding_numba: Timed best=3.115 ms, mean=3.115 +- 0.0 ms
speedup 120.94x
block_size 8 skip_size 11
im2col_sliding: Timed best=378.284 ms, mean=378.284 +- 0.0 ms
im2col_sliding_numba: Timed best=1.406 ms, mean=1.406 +- 0.0 ms
speedup 268.97x
block_size 8 skip_size 23
im2col_sliding: Timed best=376.268 ms, mean=376.268 +- 0.0 ms
im2col_sliding_numba: Timed best=661.404 us, mean=661.404 +- 0.0 us
speedup 568.89x
block_size 4 skip_size 1
im2col_sliding: Timed best=378.813 ms, mean=378.813 +- 0.0 ms
im2col_sliding_numba: Timed best=4.950 ms, mean=4.950 +- 0.0 ms
speedup 76.54x
block_size 4 skip_size 2
im2col_sliding: Timed best=377.620 ms, mean=377.620 +- 0.0 ms
im2col_sliding_numba: Timed best=2.119 ms, mean=2.119 +- 0.0 ms
speedup 178.24x
block_size 4 skip_size 5
im2col_sliding: Timed best=374.792 ms, mean=374.792 +- 0.0 ms
im2col_sliding_numba: Timed best=854.986 us, mean=854.986 +- 0.0 us
speedup 438.36x
block_size 4 skip_size 11
im2col_sliding: Timed best=373.296 ms, mean=373.296 +- 0.0 ms
im2col_sliding_numba: Timed best=415.028 us, mean=415.028 +- 0.0 us
speedup 899.45x
block_size 4 skip_size 23
im2col_sliding: Timed best=374.075 ms, mean=374.075 +- 0.0 ms
im2col_sliding_numba: Timed best=219.491 us, mean=219.491 +- 0.0 us
speedup 1704.28x
block_size 2 skip_size 1
im2col_sliding: Timed best=377.698 ms, mean=377.698 +- 0.0 ms
im2col_sliding_numba: Timed best=1.477 ms, mean=1.477 +- 0.0 ms
speedup 255.67x
block_size 2 skip_size 2
im2col_sliding: Timed best=378.155 ms, mean=378.155 +- 0.0 ms
im2col_sliding_numba: Timed best=841.298 us, mean=841.298 +- 0.0 us
speedup 449.49x
block_size 2 skip_size 5
im2col_sliding: Timed best=376.381 ms, mean=376.381 +- 0.0 ms
im2col_sliding_numba: Timed best=392.541 us, mean=392.541 +- 0.0 us
speedup 958.83x
block_size 2 skip_size 11
im2col_sliding: Timed best=374.720 ms, mean=374.720 +- 0.0 ms
im2col_sliding_numba: Timed best=193.093 us, mean=193.093 +- 0.0 us
speedup 1940.62x
block_size 2 skip_size 23
im2col_sliding: Timed best=378.092 ms, mean=378.092 +- 0.0 ms
im2col_sliding_numba: Timed best=105.101 us, mean=105.101 +- 0.0 us
speedup 3597.42x
block_size 1 skip_size 1
im2col_sliding: Timed best=203.410 ms, mean=203.410 +- 0.0 ms
im2col_sliding_numba: Timed best=686.335 us, mean=686.335 +- 0.0 us
speedup 296.37x
block_size 1 skip_size 2
im2col_sliding: Timed best=202.865 ms, mean=202.865 +- 0.0 ms
im2col_sliding_numba: Timed best=361.255 us, mean=361.255 +- 0.0 us
speedup 561.56x
block_size 1 skip_size 5
im2col_sliding: Timed best=200.929 ms, mean=200.929 +- 0.0 ms
im2col_sliding_numba: Timed best=164.740 us, mean=164.740 +- 0.0 us
speedup 1219.68x
block_size 1 skip_size 11
im2col_sliding: Timed best=202.163 ms, mean=202.163 +- 0.0 ms
im2col_sliding_numba: Timed best=96.791 us, mean=96.791 +- 0.0 us
speedup 2088.65x
block_size 1 skip_size 23
im2col_sliding: Timed best=202.492 ms, mean=202.492 +- 0.0 ms
im2col_sliding_numba: Timed best=64.527 us, mean=64.527 +- 0.0 us
speedup 3138.1x
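The speedup grows with skip because the Numba version only visits the windows it keeps: it walks window rows and uses `to_skip` to carry the leftover step from one row into the next. The pure-Python sketch below (no Numba; the name `kept_windows` is hypothetical) replays that bookkeeping without the copying, and checks that it visits exactly the flat window indices 0, skip, 2*skip, ..., i.e. the same columns as `[:, ::skip]` in the original:

```python
def kept_windows(vert_blocks, horz_blocks, skip):
    """Replay the Numba loop's row/skip bookkeeping; return kept flat window indices."""
    visited = []
    to_skip, v_b = 0, 0
    while True:
        v_b += to_skip // horz_blocks          # jump whole window rows if the step is large
        if v_b >= vert_blocks:
            break
        h_b_start = to_skip % horz_blocks      # first kept window in this row
        h_cnt = (horz_blocks - h_b_start + skip - 1) // skip
        for h_b in range(h_b_start, horz_blocks, skip):
            visited.append(v_b * horz_blocks + h_b)
        # Steps still owed after the last kept window in this row
        to_skip = skip - (horz_blocks - h_b_start - skip * (h_cnt - 1))
        v_b += 1
    return visited

for vb, hb in [(3, 3), (5, 7), (249, 249)]:
    for skip in [1, 2, 5, 11, 23]:
        assert kept_windows(vb, hb, skip) == list(range(0, vb * hb, skip))
print("window walk matches [:, ::skip]")
```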

I don't think that you can do better. Clearly, you have to run a loop of size

(cols - block_size[1] + 1) * (rows - block_size[0] + 1)

But you're taking a 3, 3 patch in your example, not a 2, 2.

You can also further optimize M Eliya's answer (although the gain is not that significant).

Instead of applying the skip at the very end, you can apply it when generating the offset arrays, so instead of:

# Get offsetted indices across the height and width of input array
offset_idx = np.arange(row_extent)[:,None]*N + np.arange(col_extent)

# Get all actual indices & index into input array for final output
out = np.take(A,start_idx.ravel()[:,None] + offset_idx[::skip[0],::skip[1]].ravel())

you would add the skips using the step parameter of NumPy's arange function:

# Get offsetted indices across the height and width of input array and add skips
offset_idx = np.arange(row_extent, step=skip[0])[:, None] * N + np.arange(col_extent, step=skip[1])

and then add the offset array directly, without the [::] indexing:

# Get all actual indices & index into input array for final output

out = np.take(A, start_idx.ravel()[:, None] + offset_idx.ravel())
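The timings below call `im2col` and `im2col_optimized`, which are not defined on this page. A plausible reconstruction, assuming both wrap the multi-channel snippet and differ only in whether the skip is applied after (`im2col`) or during (`im2col_optimized`) offset construction, is sketched here; `_start_indices` is a hypothetical helper. The asserts check that both give identical results, since `np.arange(n, step=s)` yields the same values as `np.arange(n)[::s]`:

```python
import numpy as np

def _start_indices(D, M, N, B):
    # Flat indices of the top-left window across all D channels
    start_idx = np.arange(B[0])[:, None] * N + np.arange(B[1])
    return (M * N * np.arange(D)[:, None] + start_idx.ravel()).ravel()

def im2col(A, B, skip):
    D, M, N = A.shape
    offset_idx = np.arange(M - B[0] + 1)[:, None] * N + np.arange(N - B[1] + 1)
    # Build every window offset, then subsample with slicing
    return np.take(A, _start_indices(D, M, N, B)[:, None]
                   + offset_idx[::skip[0], ::skip[1]].ravel())

def im2col_optimized(A, B, skip):
    D, M, N = A.shape
    # Build only the kept offsets by stepping arange directly
    offset_idx = (np.arange(M - B[0] + 1, step=skip[0])[:, None] * N
                  + np.arange(N - B[1] + 1, step=skip[1]))
    return np.take(A, _start_indices(D, M, N, B)[:, None] + offset_idx.ravel())

A = np.random.randint(0, 9, (3, 64, 64))
for B in ([2, 2], [10, 10]):
    for skip in ([1, 1], [2, 2], [10, 10]):
        assert np.array_equal(im2col(A, B, skip), im2col_optimized(A, B, skip))
```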

For small skip values it barely saves any time:

In[25]:
A = np.random.randint(0,9,(3, 1024, 1024))
B = [2, 2]
skip = [2, 2]

In[26]: %timeit im2col(A, B, skip)
10 loops, best of 3: 19.7 ms per loop

In[27]: %timeit im2col_optimized(A, B, skip)
100 loops, best of 3: 17.5 ms per loop

However, with larger skip values it saves noticeably more time:

In[28]: skip = [10, 10]
In[29]: %timeit im2col(A, B, skip)
100 loops, best of 3: 3.85 ms per loop

In[30]: %timeit im2col_optimized(A, B, skip)
1000 loops, best of 3: 1.02 ms per loop

With a larger array and a 10 x 10 block:

A = np.random.randint(0,9,(3, 2000, 2000))
B = [10, 10]
skip = [10, 10]

In[43]: %timeit im2col(A, B, skip)
10 loops, best of 3: 87.8 ms per loop

In[44]: %timeit im2col_optimized(A, B, skip)
10 loops, best of 3: 76.3 ms per loop

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM