
Translate Keras functional API to PyTorch nn.Module - Conv2d

I'm trying to translate the following Inception code from a tutorial written with the Keras functional API (link) to a PyTorch nn.Module:

    from tensorflow.keras.layers import Activation, BatchNormalization, Conv2D, concatenate


    def conv_module(x, K, kX, kY, stride, chanDim, padding="same"):
        # define a CONV => BN => RELU pattern
        x = Conv2D(K, (kX, kY), strides=stride, padding=padding)(x)
        x = BatchNormalization(axis=chanDim)(x)
        x = Activation("relu")(x)
        # return the block
        return x


    def inception_module(x, numK1x1, numK3x3, chanDim):
        # define two CONV modules, then concatenate across the
        # channel dimension
        conv_1x1 = conv_module(x, numK1x1, 1, 1, (1, 1), chanDim)
        conv_3x3 = conv_module(x, numK3x3, 3, 3, (1, 1), chanDim)
        x = concatenate([conv_1x1, conv_3x3], axis=chanDim)
        # return the block
        return x

I'm having trouble translating the Conv2D layer. If I understand correctly:

  1. There is no in_channels argument in Keras; how should I represent it in PyTorch?
  2. Keras filters corresponds to PyTorch out_channels.
  3. kernel_size, strides (stride in PyTorch) and padding are the same (though some padding options may be named differently).

Do I understand this correctly? If so, what should I do with in_channels? My code so far:

    from torch import nn, Tensor


    class BasicConv2d(nn.Module):
        def __init__(
                self,
                in_channels: int,
                out_channels: int,
                kernel_size: int,
                stride: int
        ) -> None:
            super().__init__()
            self.conv = nn.Conv2d(in_channels,
                                  out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride)
            self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
            self.relu = nn.ReLU()

        def forward(self, x: Tensor) -> Tensor:
            x = self.conv(x)
            x = self.bn(x)
            x = self.relu(x)
            return x

    class Inception(nn.Module):
        def __init__(
                self,
                in_channels: int,
                num_1x1_filters: int,
                num_3x3_filters: int,
        ) -> None:
            super().__init__()
            # how to fill this further?
            self.conv_1d = BasicConv2d(
                num_1x1_filters,
                )

You're correct for the most part. The in_channels parameter of nn.Conv2d is the number of channels coming out of the previous layer (that layer's out_channels). Keras infers this value from the input tensor the first time a layer is called, whereas PyTorch expects it up front. If Conv2d is the first layer, in_channels is the number of channels in your image: 1 for a grayscale image and 3 for an RGB image.
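A minimal sketch of how in_channels chains from one layer to the next, assuming an RGB input; the tensor sizes and channel counts are arbitrary illustration values. It also shows nn.LazyConv2d, which PyTorch provides (1.8 and later) to infer in_channels on the first forward call, the closest analogue to Keras's behavior.

```python
import torch
from torch import nn

# Illustrative RGB batch: (batch, channels, height, width)
rgb = torch.randn(1, 3, 32, 32)

# First layer: in_channels must equal the image's channel count (3 for RGB)
first = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
# Next layer: in_channels must equal the previous layer's out_channels
second = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)

out = second(first(rgb))
print(out.shape)  # torch.Size([1, 32, 32, 32])

# Keras-like alternative: in_channels is inferred from the first input seen
lazy = nn.LazyConv2d(out_channels=16, kernel_size=3, padding=1)
lazy_out = lazy(rgb)
# Weight layout is (out_channels, in_channels, kH, kW)
print(lazy.weight.shape)  # torch.Size([16, 3, 3, 3])
```

Explicit in_channels keeps the model's parameter shapes known at construction time; the lazy variant trades that for not having to track channel counts by hand.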

However, I'm not sure how you could concatenate the two BasicConv2d outputs as they stand.

Fixing batch_size to 1, assume that the image size is 256×256 and out_channels for the 1x1 conv is 64. That layer would output a tensor of shape torch.Size([1, 64, 256, 256]). Assuming out_channels of the 3x3 conv is 32, it would output a tensor of shape torch.Size([1, 32, 254, 254]), because an unpadded 3x3 kernel at stride 1 loses one pixel at each border (256 - 3 + 1 = 254). torch.cat requires every dimension except the concatenation dimension to match, so we cannot concatenate these two tensors without some trick, such as using padding=1 for the 3x3 conv alone: that produces an output of shape torch.Size([1, 32, 256, 256]), which can then be concatenated with the 1x1 output.
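The shapes above can be checked directly. This sketch assumes a 3-channel (RGB) input, which the example does not specify; conv_out_size is a hypothetical helper implementing the standard output-size formula for dilation 1, floor((n + 2p - k) / s) + 1. The padding="same" variant assumes PyTorch 1.9 or newer, where nn.Conv2d accepts that string (only for stride 1); it is the closest match to Keras padding="same".

```python
import torch
from torch import nn


def conv_out_size(n: int, k: int, s: int = 1, p: int = 0) -> int:
    # Hypothetical helper: output length of one spatial dim (dilation=1)
    return (n + 2 * p - k) // s + 1


x = torch.randn(1, 3, 256, 256)  # batch of one RGB 256x256 image (assumed)

conv1x1 = nn.Conv2d(3, 64, kernel_size=1)
conv3x3 = nn.Conv2d(3, 32, kernel_size=3)                       # no padding
conv3x3_pad = nn.Conv2d(3, 32, kernel_size=3, padding=1)        # explicit padding
conv3x3_same = nn.Conv2d(3, 32, kernel_size=3, padding="same")  # PyTorch >= 1.9

print(conv1x1(x).shape)       # torch.Size([1, 64, 256, 256])
print(conv3x3(x).shape)       # torch.Size([1, 32, 254, 254])
print(conv3x3_pad(x).shape)   # torch.Size([1, 32, 256, 256])
print(conv3x3_same(x).shape)  # torch.Size([1, 32, 256, 256])
print(conv_out_size(256, 3), conv_out_size(256, 3, p=1))  # 254 256
```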

Your implementation of BasicConv2d is fine; here is the code for the Inception module.

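    import torch
    from torch import nn, Tensor


    class Inception(nn.Module):
        def __init__(
                self,
                in_channels: int,
                num_1x1_filters: int,
                num_3x3_filters: int,
        ) -> None:
            super().__init__()
            # Both branches read the same input, so both take in_channels
            self.conv1 = BasicConv2d(in_channels, num_1x1_filters, 1, 1)
            # NOTE: BasicConv2d has no padding, so this 3x3 branch shrinks H and W
            # by 2; give its nn.Conv2d padding=1 or the concatenation below fails
            self.conv3 = BasicConv2d(in_channels, num_3x3_filters, 3, 1)

        def forward(self, x: Tensor) -> Tensor:
            conv1_out = self.conv1(x)
            conv3_out = self.conv3(x)
            # Concatenate along the channel dimension (dim=1 for NCHW tensors),
            # the counterpart of Keras axis=chanDim
            x = torch.cat([conv1_out, conv3_out], dim=1)
            return x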
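Putting the pieces together, here is a self-contained sketch of both Keras functions as modules. ConvModule and InceptionModule are hypothetical names chosen to mirror conv_module and inception_module; unlike BasicConv2d above, ConvModule takes padding=kernel_size // 2, which keeps H and W unchanged for odd kernels at stride 1 (matching Keras padding="same" in that case). The BatchNorm settings are an assumption that you want Keras's defaults (epsilon=0.001, momentum=0.99); PyTorch's momentum weights the new batch statistic, so Keras 0.99 corresponds to PyTorch 0.01.

```python
import torch
from torch import nn, Tensor


class ConvModule(nn.Module):
    """CONV => BN => RELU, mirroring the Keras conv_module (hypothetical name)."""

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1) -> None:
        super().__init__()
        # kernel_size // 2 keeps H and W unchanged for odd kernels at stride 1
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=kernel_size // 2)
        # Keras defaults: epsilon=1e-3, momentum=0.99 (PyTorch momentum=1 - 0.99)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        return self.relu(self.bn(self.conv(x)))


class InceptionModule(nn.Module):
    """1x1 and 3x3 branches concatenated on the channel axis (hypothetical name)."""

    def __init__(self, in_channels: int, num_k1x1: int, num_k3x3: int) -> None:
        super().__init__()
        self.branch1x1 = ConvModule(in_channels, num_k1x1, kernel_size=1)
        self.branch3x3 = ConvModule(in_channels, num_k3x3, kernel_size=3)
        # The next layer needs this value as its in_channels
        self.out_channels = num_k1x1 + num_k3x3

    def forward(self, x: Tensor) -> Tensor:
        # dim=1 is the channel axis for NCHW tensors
        return torch.cat([self.branch1x1(x), self.branch3x3(x)], dim=1)


block = InceptionModule(in_channels=3, num_k1x1=64, num_k3x3=32)
y = block(torch.randn(2, 3, 32, 32))
print(y.shape)  # torch.Size([2, 96, 32, 32])
```

Exposing out_channels on the block is one convenient way to feed the next layer's in_channels, since PyTorch will not infer it for you.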

You need to define two basic conv layers and apply each of them separately to the same input in the forward pass. As @planet_pluto pointed out, you can't concatenate two feature maps that have different spatial sizes. You can choose stride and padding values that produce two feature maps of the same size or, alternatively, upsample or downsample one of them before concatenating.
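The resize-before-concatenating alternative can be sketched as follows, reusing the 256×256 example with an unpadded 3x3 conv and an assumed RGB input. Bilinear interpolation (for upsampling) and adaptive average pooling (for downsampling) are just two possible choices picked for illustration; the answer does not prescribe a specific method.

```python
import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(1, 3, 256, 256)            # assumed RGB input
conv1x1 = nn.Conv2d(3, 64, kernel_size=1)  # output: (1, 64, 256, 256)
conv3x3 = nn.Conv2d(3, 32, kernel_size=3)  # unpadded output: (1, 32, 254, 254)

a = conv1x1(x)
b = conv3x3(x)

# Option 1: upsample the smaller map to the larger one's H and W
b_up = F.interpolate(b, size=tuple(a.shape[-2:]), mode="bilinear",
                     align_corners=False)
up = torch.cat([a, b_up], dim=1)

# Option 2: downsample the larger map to the smaller one's H and W
a_down = F.adaptive_avg_pool2d(a, output_size=tuple(b.shape[-2:]))
down = torch.cat([a_down, b], dim=1)

print(up.shape)    # torch.Size([1, 96, 256, 256])
print(down.shape)  # torch.Size([1, 96, 254, 254])
```

Matching the padding is usually the simpler fix; resizing adds an extra operation and changes the feature values through interpolation or pooling.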

The technical posts on this site are published under the CC BY-SA 4.0 license. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM