
What is the algebraic expression for PyTorch's ConvTranspose2d's output shape?

When using PyTorch's ConvTranspose2d as such:

import torch.nn as nn

w = 5  # input width
h = 5  # input height
in_channels, out_channels, k, s, p = 1, 1, 3, 2, 0  # example values only
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=s, padding=p)

What is the formula for the height and width of the output in each channel? I tried a few examples and cannot derive the pattern. For some reason, adding padding seems to shrink the output (each example starts from the 5 x 5 input above):

# yields an 11 x 11 image
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) 

# yields a 7 x 7 image
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=2)

Using a larger kernel or a larger stride increases the output size, as expected, but not at the rate I expected:

# yields an 11 x 11 image
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) 

# yields a 13 x 13 image
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=5, stride=2, padding=0)

# yields a 15 x 15 image
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=3, padding=0)

I'm sure there's a pretty simple equation involving w, h, k, s, and p, but I can't find it in the documentation and I haven't been able to derive it myself. Normally I wouldn't ask for a math equation, but it directly determines whether my CNN builds and produces output of the correct size. Thanks in advance!

The formula for the ConvTranspose2d output size is given on its documentation page:

H_out = (H_in − 1) × stride[0] − 2 × padding[0] + dilation[0] × (kernel_size[0] − 1) + output_padding[0] + 1

W_out = (W_in − 1) × stride[1] − 2 × padding[1] + dilation[1] × (kernel_size[1] − 1) + output_padding[1] + 1

By default, stride=1, padding=0, output_padding=0, and dilation=1.
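The formula translates directly into a small helper for checking layer shapes before building a network. This is a minimal sketch, not part of PyTorch: the function name `conv_transpose2d_output_size` is hypothetical, and it assumes each argument is either an int or an (H, W) pair, mirroring how ConvTranspose2d accepts them. Its defaults (stride=1, padding=0, output_padding=0, dilation=1) mirror the layer's.

```python
from typing import Tuple, Union

IntOrPair = Union[int, Tuple[int, int]]


def _pair(x: IntOrPair) -> Tuple[int, int]:
    """Expand an int to (x, x); pass an (H, W) tuple through unchanged."""
    return (x, x) if isinstance(x, int) else (x[0], x[1])


def conv_transpose2d_output_size(
    h_in: int,
    w_in: int,
    kernel_size: IntOrPair,
    stride: IntOrPair = 1,
    padding: IntOrPair = 0,
    output_padding: IntOrPair = 0,
    dilation: IntOrPair = 1,
) -> Tuple[int, int]:
    """Return (H_out, W_out) using the ConvTranspose2d shape formula."""
    k, s, p = _pair(kernel_size), _pair(stride), _pair(padding)
    op, d = _pair(output_padding), _pair(dilation)
    sizes = []
    # Index 0 is the height dimension, index 1 is the width dimension.
    for i, n_in in enumerate((h_in, w_in)):
        n_out = (n_in - 1) * s[i] - 2 * p[i] + d[i] * (k[i] - 1) + op[i] + 1
        sizes.append(n_out)
    return sizes[0], sizes[1]


print(conv_transpose2d_output_size(5, 5, kernel_size=3, stride=2))  # (11, 11)
```

Passing tuples lets you check non-square kernels or strides, e.g. kernel_size=(3, 5), stride=(2, 3) on a 5 x 5 input gives (11, 17).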

For example, for

nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) 

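H_out is

H_out = (5 − 1) × 2 − 2 × 0 + 1 × (3 − 1) + 0 + 1 = 11

and, since the input is square, W_out is also 11. The −2 × padding term explains why adding padding shrinks the output: with padding=2, the same layer gives (5 − 1) × 2 − 2 × 2 + 1 × (3 − 1) + 0 + 1 = 7.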
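As a quick check, the one-dimensional form of the formula reproduces every shape reported in the question. This sketch assumes square 5 x 5 inputs with the default dilation and output_padding, so a single helper (the name `out_size` is hypothetical) covers both H and W.

```python
def out_size(n_in, k, s=1, p=0, d=1, op=0):
    """Transposed-convolution output length along one spatial dimension."""
    return (n_in - 1) * s - 2 * p + d * (k - 1) + op + 1


# (kernel_size, stride, padding) -> size observed in the question
observed = {
    (3, 2, 0): 11,
    (3, 2, 2): 7,
    (5, 2, 0): 13,
    (3, 3, 0): 15,
}

for (k, s, p), expected in observed.items():
    got = out_size(5, k, s, p)
    print(f"k={k} s={s} p={p}: {got} x {got} (question reports {expected} x {expected})")
    assert got == expected
```

Reading the formula term by term also explains the growth rates in the question: raising the stride by 1 adds H_in − 1 = 4 rows (11 to 15), while raising kernel_size by 1 adds only dilation = 1 row, so going from 3 to 5 adds 2 (11 to 13).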
