
Why does torch.nn.Conv2d() divide the image into 9 parts?

Sorry for the stupid question, but why does torch.nn.Conv2d() divide the image into 9 parts?

import torch
from torch import nn
import cv2

img = cv2.imread("image_game/eldenring 2022-12-14 19-29-50.png")
cv2.imshow('input', img)
size = img.shape #  (720, 1280, 3)
img = img.reshape((1, img.shape[2], size[0], size[1]))
img = torch.tensor(img, dtype=torch.float32)  #  torch.Size([1, 3, 720, 1280])

c1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=2, stride=1)
img = c1(img)

size = img.shape
img = img.reshape((size[2], size[3], size[1])).detach().numpy()
cv2.imshow('output', img)
cv2.waitKey(0)

It returns this:

Input image: [image]
Output image: [image]

I want this:

[animated image]

[image]

edit:

When I use

c1 = nn.Conv2d(1, 1, kernel_size=(3, 3), padding=2, stride=1)

instead of

c1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=2, stride=1)

I get what I want, but how do I do it when there are more channels?
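A likely reason the single-channel version looked right, assuming that run used a one-channel (grayscale) image, which the question does not state: when C = 1, moving the channel axis does not change the order of the elements in memory, so `reshape` and a real axis permutation give the same array. With C = 3 they differ. A small NumPy check on tiny arrays (the helper names below are hypothetical):

```python
import numpy as np

def hwc_to_chw_reshape(img):
    # Reinterprets the same row-major buffer with a new shape (what the question's code does).
    h, w, c = img.shape
    return img.reshape(c, h, w)

def hwc_to_chw_transpose(img):
    # Actually moves the channel axis to the front, keeping each pixel's values together.
    return img.transpose(2, 0, 1)

gray = np.arange(4 * 6 * 1).reshape(4, 6, 1)   # tiny one-channel "image"
color = np.arange(4 * 6 * 3).reshape(4, 6, 3)  # tiny three-channel "image"

same_gray = np.array_equal(hwc_to_chw_reshape(gray), hwc_to_chw_transpose(gray))
same_color = np.array_equal(hwc_to_chw_reshape(color), hwc_to_chw_transpose(color))
print(same_gray, same_color)  # True False
```

Both helpers return an array of shape (C, H, W); only the transpose version keeps channel c of pixel (y, x) at position [c, y, x].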

I'm sorry that the description of the question was unclear. Javier TG solved my problem in a comment:

The issue is with using reshape to permute the axes: OpenCV's imread gives an array of shape (H, W, 3), so to get PyTorch's (1, 3, H, W) representation, transpose (in NumPy) and permute (in PyTorch) should be used instead. Try substituting the first reshape with img = img[None].transpose(0, 3, 1, 2), and the last reshape with img = img[0].permute(1, 2, 0).detach().numpy() – Javier TG
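The round trip described in that comment can be sketched with NumPy alone, using a small made-up array as a stand-in for the cv2.imread result (the helper names are hypothetical). It checks that every pixel keeps its channel values after the move to (1, C, H, W), and that the inverse permutation restores the original layout:

```python
import numpy as np

def hwc_to_nchw(img):
    # (H, W, C) -> (1, C, H, W): add a batch axis, then move channels in front of H and W.
    return img[None].transpose(0, 3, 1, 2)

def nchw_to_hwc(batch):
    # (1, C, H, W) -> (H, W, C): drop the batch axis, then move channels back to the end.
    return batch[0].transpose(1, 2, 0)

img = np.arange(6 * 8 * 3).reshape(6, 8, 3)  # small stand-in for a cv2.imread result
nchw = hwc_to_nchw(img)
print(nchw.shape)  # (1, 3, 6, 8)

# Channel c of pixel (y, x) lands at [0, c, y, x], so the image content is preserved.
y, x, c = 5, 7, 2
print(nchw[0, c, y, x] == img[y, x, c])  # True

# Applying the inverse permutation restores the original array exactly.
print(np.array_equal(nchw_to_hwc(nchw), img))  # True
```

On a PyTorch tensor, `tensor[0].permute(1, 2, 0)` performs the same axis move as `nchw_to_hwc`.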

I thought the problem was in the nn.Conv2d() function, but I had just transposed the data incorrectly.

Corrected code:

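import torch
from torch import nn
import cv2

img = cv2.imread("image_game/eldenring 2022-12-14 19-29-50.png")  # uint8, BGR, shape (720, 1280, 3)
cv2.imshow('input', img)
# (H, W, C) -> (1, C, H, W): add a batch axis, then move the channel axis with transpose (not reshape)
img = img[None].transpose(0, 3, 1, 2)
img = torch.as_tensor(img).float()  # torch.Size([1, 3, 720, 1280])

# padding=1 with a 3x3 kernel and stride 1 keeps the output at 720 x 1280
c1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=1, stride=1)
img = c1(img)

# (1, C, H, W) -> (H, W, C): drop the batch axis, then move channels last with permute
img = img[0].permute(1, 2, 0).detach().numpy()  # (720, 1280, 3)
# cv2.imshow multiplies float32 pixel values by 255, so values outside [0, 1] saturate
cv2.imshow('output', img)
cv2.waitKey(0)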
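The corrected code also changes padding=2 to padding=1. With a 3x3 kernel and stride 1, padding 1 keeps the output at 720 x 1280, while padding 2 adds two rows and two columns. This can be checked in plain Python with the output-size formula from the torch.nn.Conv2d documentation (`conv2d_out_size` is a hypothetical helper name):

```python
def conv2d_out_size(size, kernel, padding=0, stride=1, dilation=1):
    # Output-size formula from the torch.nn.Conv2d docs, applied to one spatial dimension.
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

for pad in (2, 1):
    h = conv2d_out_size(720, kernel=3, padding=pad)
    w = conv2d_out_size(1280, kernel=3, padding=pad)
    print(pad, (h, w))
# 2 (722, 1282)
# 1 (720, 1280)
```

In general, a k x k kernel with stride 1 and dilation 1 preserves the spatial size when padding = (k - 1) / 2 for odd k.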
