Is there a function to extract image patches in PyTorch?

Given a batch of images, I'd like to extract all possible image patches, similar to a convolution. In TensorFlow, we can use tf.extract_image_patches to achieve this. Is there an equivalent function in PyTorch?

Thank you.

Unfortunately, there might not be a function with the same name in PyTorch, but Tensor.unfold can be used to achieve your goal.
This PyTorch forum thread may help:
https://discuss.pytorch.org/t/how-to-extract-smaller-image-patches-3d/16837/2
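As a minimal sketch of what Tensor.unfold(dimension, size, step) does (the toy tensors here are chosen purely for illustration): it slides a window of length size along one dimension with stride step and appends each window as a new last dimension. Chaining two calls, one per spatial dimension, gives 2D patches.

```python
import torch

# 1D: windows of length 3, stride 2, along dimension 0
v = torch.arange(7)
print(v.unfold(0, 3, 2))
# tensor([[0, 1, 2],
#         [2, 3, 4],
#         [4, 5, 6]])

# 2D: non-overlapping 2x2 patches of a 4x4 "image"
img = torch.arange(16).view(4, 4)
# first call windows the rows, second call windows the columns
patches = img.unfold(0, 2, 2).unfold(1, 2, 2)  # [n_win_rows, n_win_cols, 2, 2]
print(patches.shape)   # torch.Size([2, 2, 2, 2])
print(patches[0, 1])   # top-right patch
# tensor([[2, 3],
#         [6, 7]])
```

The windows overlap whenever step < size, which is how you get the "every possible patch" behaviour of a convolution.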

This code example, inspired by the thread linked by @gasoon but a bit more verbose, may help show how to use unfold:

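import torch

batch_size, n_channels, n_rows, n_cols = 32, 3, 64, 64
kernel_h, kernel_w = 7, 9
step = 5

x = torch.arange(batch_size*n_channels*n_rows*n_cols).view(batch_size, n_channels, n_rows, n_cols)

# unfold(dimension, size, step): slide a window of length `size` along `dimension` with stride `step`
# shape after both calls: [batch_size, n_channels, 12, 12, kernel_h, kernel_w]
windows = x.unfold(2, kernel_h, step).unfold(3, kernel_w, step)
# move the 12 x 12 window grid in front of the batch dim, then flatten grid and batch into one dim
windows = windows.permute(2, 3, 0, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)
print(windows.shape)
# result: torch.Size([4608, 3, 7, 9]) = [n_windows, n_channels, kernel_h, kernel_w]
# (n_windows = 12 * 12 windows per image * 32 images)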
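To see how a flat index into windows maps back to a position in the batch, here is a small sketch with made-up sizes (batch 2, 3 channels, a 10x12 image, 3x4 windows, step 2). The helper window_index is hypothetical, written only for this illustration; it relies on permute(2, 3, 0, 1, 4, 5) making (window row, window column, image) the leading dimensions before the reshape.

```python
import torch

batch_size, n_channels, n_rows, n_cols = 2, 3, 10, 12
kernel_h, kernel_w, step = 3, 4, 2

x = torch.arange(batch_size * n_channels * n_rows * n_cols).view(batch_size, n_channels, n_rows, n_cols)
windows = (x.unfold(2, kernel_h, step)
            .unfold(3, kernel_w, step)
            .permute(2, 3, 0, 1, 4, 5)
            .reshape(-1, n_channels, kernel_h, kernel_w))

# number of window positions along each axis (no padding, like a 'VALID' convolution)
n_h = (n_rows - kernel_h) // step + 1   # 4
n_w = (n_cols - kernel_w) // step + 1   # 5

def window_index(r, c, b):
    """Hypothetical helper: flat index into `windows` for window row r, window column c, image b."""
    return (r * n_w + c) * batch_size + b

# every window equals the corresponding slice of the original batch
for r in range(n_h):
    for c in range(n_w):
        for b in range(batch_size):
            expected = x[b, :, r * step:r * step + kernel_h, c * step:c * step + kernel_w]
            assert torch.equal(windows[window_index(r, c, b)], expected)

print(windows.shape)  # torch.Size([40, 3, 3, 4])
```

If you would rather keep the windows of each image together, permute(0, 2, 3, 1, 4, 5) instead puts the batch dimension first.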

I spent a bit of time looking into this as well and found this PyTorch forum thread useful: PyTorch dev ptrblck (bless this dude) gives an equivalent PyTorch version of the TensorFlow function there.

I'll just repost the code (based on the version from user FloCF) here for simplicity.

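import math
import torch.nn.functional as F

def extract_image_patches(x, kernel, stride=1, dilation=1):
    # x: [batch, channels, height, width]; square kernel of size `kernel`
    b, c, h, w = x.shape
    # Do TF 'SAME' padding: output size is ceil(input / stride)
    h2 = math.ceil(h / stride)
    w2 = math.ceil(w / stride)
    eff_k = (kernel - 1) * dilation + 1  # spatial extent of the dilated kernel
    pad_row = max((h2 - 1) * stride + eff_k - h, 0)
    pad_col = max((w2 - 1) * stride + eff_k - w, 0)
    # F.pad pads the last dimension first: (left, right, top, bottom)
    x = F.pad(x, (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2))

    # Extract patches: [b, c, h2, w2, eff_k, eff_k], then keep every `dilation`-th pixel
    patches = x.unfold(2, eff_k, stride).unfold(3, eff_k, stride)
    patches = patches[..., ::dilation, ::dilation]            # [b, c, h2, w2, kernel, kernel]
    patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()  # [b, kernel, kernel, c, h2, w2]

    # Flatten each patch into the channel dim: [b, kernel*kernel*c, h2, w2]
    return patches.view(b, -1, patches.shape[-2], patches.shape[-1])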
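If you need TensorFlow's channels-last output layout, [batch, out_rows, out_cols, kernel*kernel*channels], rather than the channels-first one returned above, here is a possible alternative sketch. The name extract_patches_tf_layout is hypothetical. It does the same 'SAME' padding but uses torch.nn.functional.unfold, which handles stride and dilation directly, and it assumes a square kernel like the function above. Flattening each patch in (row, col, channel) order is my reading of the TensorFlow docs, so compare against tf.image.extract_patches if the exact element order matters to you.

```python
import math
import torch
import torch.nn.functional as F

def extract_patches_tf_layout(x, kernel, stride=1, dilation=1):
    """Sketch: TF-style 'SAME' patch extraction built on F.unfold.

    x: [b, c, h, w] -> [b, out_h, out_w, kernel*kernel*c], each patch
    flattened in (row, col, channel) order.
    """
    b, c, h, w = x.shape
    out_h, out_w = math.ceil(h / stride), math.ceil(w / stride)
    eff_k = (kernel - 1) * dilation + 1  # extent of the dilated kernel
    pad_h = max((out_h - 1) * stride + eff_k - h, 0)
    pad_w = max((out_w - 1) * stride + eff_k - w, 0)
    # 'SAME' puts any odd extra pixel on the bottom/right; F.pad order is (left, right, top, bottom)
    x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
    # F.unfold returns [b, c*kernel*kernel, out_h*out_w], each column flattened in (c, kh, kw) order
    cols = F.unfold(x, kernel_size=kernel, dilation=dilation, stride=stride)
    cols = cols.reshape(b, c, kernel * kernel, out_h, out_w)
    # reorder each patch to (kh, kw, c) and move it to the last dimension
    return cols.permute(0, 3, 4, 2, 1).reshape(b, out_h, out_w, kernel * kernel * c)

x = torch.randn(2, 3, 7, 9)
print(extract_patches_tf_layout(x, kernel=3, stride=2).shape)  # torch.Size([2, 4, 5, 27])
```

To go between the two layouts, permute(0, 2, 3, 1) turns the [b, kernel*kernel*c, h2, w2] output of the function above into this channels-last form.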

Give those guys a like on the PyTorch forum :)