Is there a function to extract image patches in PyTorch?

Given a batch of images, I'd like to extract all possible image patches, similar to a convolution. In TensorFlow, we can use tf.extract_image_patches to achieve this. Is there an equivalent function in PyTorch?

Thank you.

Unfortunately, there may not be a direct equivalent, but the Tensor.unfold function might be a solution.
This thread might help:
https://discuss.pytorch.org/t/how-to-extract-smaller-image-patches-3d/16837/2
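
As a minimal sketch of what Tensor.unfold does (the tensor sizes here are chosen only for illustration), it slides a window of a given size along one dimension with a given step. Applying it once per spatial dimension gives 2D patches:

```python
import torch

t = torch.arange(10)
# unfold(dimension, size, step): windows of length 3, starting every 2 elements
w = t.unfold(0, 3, 2)
print(w)
# tensor([[0, 1, 2],
#         [2, 3, 4],
#         [4, 5, 6],
#         [6, 7, 8]])

# Unfolding both spatial dimensions of a (batch, channel, row, col) tensor
img = torch.arange(16).view(1, 1, 4, 4)
p = img.unfold(2, 2, 2).unfold(3, 2, 2)
print(p.shape)        # torch.Size([1, 1, 2, 2, 2, 2])
print(p[0, 0, 0, 1])  # patch at window row 0, window column 1
# tensor([[2, 3],
#         [6, 7]])
```

The two new trailing dimensions hold the patch contents, and the dimensions that were unfolded now index the window positions.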

Maybe this code example will help show how to use unfold. It is inspired by the thread linked by @gasoon, but a bit more verbose:

import torch

batch_size, n_channels, n_rows, n_cols = 32, 3, 64, 64
kernel_h, kernel_w = 7, 9
step = 5

x = torch.arange(batch_size*n_channels*n_rows*n_cols).view(batch_size, n_channels, n_rows, n_cols)

# unfold(dimension, size, step)
windows = x.unfold(2, kernel_h, step).unfold(3, kernel_w, step).permute(2, 3, 0, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)
print(windows.shape)
# result: torch.Size([4608, 3, 7, 9]) = [n_windows, n_channels, kernel_h, kernel_w]
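
For comparison, here is a sketch of the same extraction using torch.nn.functional.unfold, which performs both spatial unfolds in one call. It assumes a floating-point input and a smaller batch than above to keep the example light. Note that the patches come out grouped by image rather than by window position, so the row order differs from the snippet above:

```python
import torch
import torch.nn.functional as F

batch_size, n_channels, n_rows, n_cols = 2, 3, 64, 64
kernel_h, kernel_w = 7, 9
step = 5

# F.unfold is used here with a floating-point (N, C, H, W) tensor
x = torch.arange(batch_size * n_channels * n_rows * n_cols, dtype=torch.float32)
x = x.view(batch_size, n_channels, n_rows, n_cols)

# Output shape: (N, C * kernel_h * kernel_w, L), L = number of windows per image
cols = F.unfold(x, kernel_size=(kernel_h, kernel_w), stride=step)

# Rearrange to (N * L, C, kernel_h, kernel_w)
patches = cols.transpose(1, 2).reshape(-1, n_channels, kernel_h, kernel_w)

# Same patches via Tensor.unfold, permuted so that images come first
ref = (x.unfold(2, kernel_h, step).unfold(3, kernel_w, step)
        .permute(0, 2, 3, 1, 4, 5)
        .reshape(-1, n_channels, kernel_h, kernel_w))

print(patches.shape)              # torch.Size([288, 3, 7, 9])
print(torch.equal(patches, ref))  # True
```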

I spent a bit of time looking into this as well and found a useful PyTorch forum thread in which PyTorch dev ptrblck (bless this dude) gives a PyTorch equivalent of the TensorFlow function.

I'll just repost the code (from user FloCF) here for simplicity.

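import math
import torch.nn.functional as F

def extract_image_patches(x, kernel, stride=1, dilation=1):
    # Do TF 'SAME' padding
    b, c, h, w = x.shape
    h2 = math.ceil(h / stride)
    w2 = math.ceil(w / stride)
    k_eff = (kernel - 1) * dilation + 1  # extent covered by a dilated kernel
    pad_row = max((h2 - 1) * stride + k_eff - h, 0)
    pad_col = max((w2 - 1) * stride + k_eff - w, 0)
    # F.pad pads the last dimension first: (left, right, top, bottom)
    x = F.pad(x, (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2))

    # Extract patches, keeping every `dilation`-th element inside each window
    patches = x.unfold(2, k_eff, stride).unfold(3, k_eff, stride)
    patches = patches[..., ::dilation, ::dilation]
    # (b, c, h2, w2, kernel, kernel) -> (b, kernel, kernel, c, h2, w2)
    patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()

    # Flatten to (b, kernel * kernel * c, h2, w2)
    return patches.view(b, -1, patches.shape[-2], patches.shape[-1])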

Give those guys a like on the PyTorch forum :)
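
To make the 'SAME' padding arithmetic concrete, here is a small sketch for a single dimension. The helper name same_padding is hypothetical (it is not part of PyTorch or TensorFlow), and clamping the total at zero is an assumption meant to mirror how TensorFlow describes 'SAME' padding:

```python
import math

def same_padding(size, kernel, stride=1, dilation=1):
    """Return (before, after) padding along one dimension, TF 'SAME' style.

    Hypothetical helper for illustration only.
    """
    out = math.ceil(size / stride)  # number of output positions
    total = max((out - 1) * stride + (kernel - 1) * dilation + 1 - size, 0)
    # Any odd leftover pixel goes on the "after" side
    return total // 2, total - total // 2

print(same_padding(64, kernel=3, stride=1))  # (1, 1)
print(same_padding(64, kernel=3, stride=2))  # (0, 1)
print(same_padding(64, kernel=7, stride=5))  # (1, 2)
```

With stride 1 the output keeps the input size. With larger strides the output has ceil(size / stride) positions, and the padding is just enough for the last window to fit.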
