簡體   English   中英

如何從 pytorch 中的圖像中提取補丁?

[英]How to extract patches from an image in pytorch?

我想從補丁大小為 128、步幅為 32 的圖像中提取圖像補丁,所以我有這段代碼,但它給了我一個錯誤:

from PIL import Image 
img = Image.open("cat.jpg")
x = transforms.ToTensor()(img)

x = x.unsqueeze(0)

size = 128 # patch size
stride = 32 # patch stride
patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)

我得到的錯誤是:

RuntimeError: maximum size for tensor at dimension 1 is 3 but size is 128

這是迄今為止我發現的唯一方法。 但它給了我這個錯誤

你的x的大小是[1, 3, height, width] 調用x.unfold(1, size, stride)嘗試從尺寸為 3 的維度 1 創建大小為 128 的切片,因此它太小而無法創建任何切片。

您不想創建跨維度 1 的切片,因為這些是圖像的通道(在本例中為 RGB),並且它們需要保持原樣用於所有補丁。 僅在圖像的高度和寬度上創建補丁。

patches = x.unfold(2, size, stride).unfold(3, size, stride)

生成的張量將具有大小[1, 3, num_vertical_slices, num_horizontal_slices, 128, 128] 您可以對其進行整形以組合切片以獲得補丁列表,即[1, 3, num_patches, 128, 128]的大小:

patches = patches.reshape(1, 3, -1, size, size)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM