
How to extract patches from an image in pytorch?

I want to extract patches from an image with patch size 128 and stride 32, but this code gives me an error:

from PIL import Image
from torchvision import transforms

img = Image.open("cat.jpg")
x = transforms.ToTensor()(img)

x = x.unsqueeze(0)

size = 128 # patch size
stride = 32 # patch stride
patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)

and the error I get is:

RuntimeError: maximum size for tensor at dimension 1 is 3 but size is 128

This is the only method I've found so far, but it gives me this error.

Your x has size [1, 3, height, width]. Calling x.unfold(1, size, stride) tries to create slices of size 128 along dimension 1, which only has size 3, so it is too small to yield even a single slice.

You don't want to create slices across dimension 1, since those are the channels of the image (RGB in this case) and they need to be kept as they are for all patches. The patches are only created across the height and width of an image.
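To see how unfold behaves on a single dimension, here is a minimal sketch on a 1-D tensor (the tensor and values are made up for illustration): unfold(dim, size, step) produces sliding windows of the given size along that dimension.

```python
import torch

# unfold(dim, size, step) creates sliding windows along one dimension
t = torch.arange(10)
windows = t.unfold(0, 3, 2)  # windows of length 3, step 2

# number of windows: (10 - 3) // 2 + 1 = 4
print(windows.shape)  # torch.Size([4, 3])
print(windows[0])     # tensor([0, 1, 2])
print(windows[-1])    # tensor([6, 7, 8])
```

This is why calling it on the channel dimension fails: a window of size 128 cannot fit in a dimension of size 3.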

patches = x.unfold(2, size, stride).unfold(3, size, stride)

The resulting tensor will have size [1, 3, num_vertical_slices, num_horizontal_slices, 128, 128]. You can reshape it to combine the two slice dimensions into a single list of patches, i.e. size [1, 3, num_patches, 128, 128]:

patches = patches.reshape(1, 3, -1, size, size)
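Putting it together, here is a self-contained sketch using a random tensor in place of the loaded image (the 256×256 spatial size is an assumption purely for illustration):

```python
import torch

size, stride = 128, 32
x = torch.randn(1, 3, 256, 256)  # stand-in for the image tensor

# slice only the spatial dimensions: 2 (height) and 3 (width)
patches = x.unfold(2, size, stride).unfold(3, size, stride)
# (256 - 128) // 32 + 1 = 5 slices per spatial dimension
print(patches.shape)  # torch.Size([1, 3, 5, 5, 128, 128])

# flatten the two slice dimensions into one patch dimension
patches = patches.reshape(1, 3, -1, size, size)
print(patches.shape)  # torch.Size([1, 3, 25, 128, 128])
```

Note that unfold returns a view of the original data; reshape will copy if needed, since the unfolded view is non-contiguous.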
