
How to construct a CNN with a 400-node hidden layer using PyTorch?

I would like to build a CNN with a 400-node layer in PyTorch, like the one below, under these conditions: a fully connected linear layer with 400 hidden neurons, an output size of 10, each image flattened to a vector as input, and ReLU activations.

When I print x.shape, I get torch.Size([1024, 300]) and torch.Size([1024, 10]), and the input to my first layer has shape torch.Size([100, 3, 32, 32]). I am confused about how to construct this simple CNN and what I am missing.

import torch
import torch.nn as nn


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = torch.nn.Linear(400, 10)

    def forward(self, x):
        x = x.view(-1, 400)   # reshape the input into rows of 400 values
        print(x.shape)
        x = self.relu(self.fc1(x))
        print(x.shape)

        return x
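For reference, the requirements listed in the question (flatten each image to a vector, one fully connected hidden layer of 400 ReLU units, 10 outputs) describe a plain multilayer perceptron with no convolution at all. Below is a minimal sketch, assuming 3×32×32 inputs as in the batch shape reported above; the class name `FlatNet` is hypothetical. It also shows why `x.view(-1, 400)` misbehaves: it regroups the whole batch into rows of 400 values instead of keeping one row per image.

```python
import torch
import torch.nn as nn


class FlatNet(nn.Module):
    """Hypothetical MLP: flatten -> 400 hidden ReLU units -> 10 outputs."""

    def __init__(self, in_features=3 * 32 * 32, hidden=400, n_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)  # 3072 -> 400
        self.fc2 = nn.Linear(hidden, n_classes)    # 400 -> 10
        self.relu = nn.ReLU()

    def forward(self, x):
        x = torch.flatten(x, 1)       # keep batch dim: (N, 3, 32, 32) -> (N, 3072)
        x = self.relu(self.fc1(x))    # ReLU on the hidden layer only
        return self.fc2(x)            # raw scores (logits) for 10 classes


batch = torch.randn(100, 3, 32, 32)

# view(-1, 400) ignores image boundaries: 100*3*32*32 / 400 = 768 rows
print(batch.view(-1, 400).shape)   # torch.Size([768, 400])

model = FlatNet()
out = model(batch)
print(out.shape)                   # torch.Size([100, 10])
```

Applying ReLU to the final layer, as the question's forward does, would clip every negative class score to zero, so in this sketch it is applied only to the hidden layer.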



I think you want one convolutional layer first, then an intermediate linear layer with 400 units, and finally a linear layer with 10 units for classification.

You need a network like this:

import torch
import torch.nn as nn
import torch.nn.functional as F


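class Net(nn.Module):
    def __init__(self, n_channel=3, final_conv_feature_size=(8, 8), conv_filters=32):
        super(Net, self).__init__()
        # first layer: 32 filters, 3x3 kernel, stride 1 (a 32x32 input gives 30x30 maps)
        self.conv1 = nn.Conv2d(n_channel, conv_filters, 3, 1)
        # pool each feature map to a fixed 8x8 size, whatever the input resolution
        self.adapt = nn.AdaptiveMaxPool2d(final_conv_feature_size)
        # 32 * 8 * 8 = 2048 flattened features -> 400 hidden units
        self.fc1 = nn.Linear(conv_filters * final_conv_feature_size[0] * final_conv_feature_size[1], 400)
        self.fc2 = nn.Linear(400, 10)  # 400 hidden units -> 10 classes

    def forward(self, x):
        x = self.conv1(x)                 # (N, 3, 32, 32) -> (N, 32, 30, 30)
        x = F.relu(x)
        x = self.adapt(x)                 # -> (N, 32, 8, 8)
        x = torch.flatten(x, 1)           # flatten all but the batch dim -> (N, 2048)
        x = self.fc1(x)                   # -> (N, 400)
        x = F.relu(x)
        x = self.fc2(x)                   # -> (N, 10)
        output = F.log_softmax(x, dim=1)  # log-probabilities over the 10 classes
        return output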

model = Net()
print(model)

Output:

Net(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (adapt): AdaptiveMaxPool2d(output_size=(8, 8))
  (fc1): Linear(in_features=2048, out_features=400, bias=True)
  (fc2): Linear(in_features=400, out_features=10, bias=True)
)
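A quick way to check the model is to push a dummy batch through it and take one training step. The sketch below repeats the answer's `Net` so it runs on its own, assumes a batch of 100 RGB 32×32 images as in the question, and uses random labels purely for illustration; the SGD learning rate is an arbitrary choice. Because `forward` ends in `log_softmax`, the matching loss is `nn.NLLLoss` (`nn.CrossEntropyLoss` expects raw logits, since it applies `log_softmax` itself).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Same architecture as the answer above, repeated so this runs standalone
    def __init__(self, n_channel=3, final_conv_feature_size=(8, 8), conv_filters=32):
        super().__init__()
        self.conv1 = nn.Conv2d(n_channel, conv_filters, 3, 1)
        self.adapt = nn.AdaptiveMaxPool2d(final_conv_feature_size)
        self.fc1 = nn.Linear(conv_filters * final_conv_feature_size[0] * final_conv_feature_size[1], 400)
        self.fc2 = nn.Linear(400, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.adapt(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


torch.manual_seed(0)
model = Net()
images = torch.randn(100, 3, 32, 32)      # dummy batch, shape taken from the question
labels = torch.randint(0, 10, (100,))     # random class indices, illustration only

log_probs = model(images)
print(log_probs.shape)                    # torch.Size([100, 10])

# log_softmax output pairs with NLLLoss
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

optimizer.zero_grad()
loss = criterion(log_probs, labels)
loss.backward()
optimizer.step()
print(loss.item())

# Adaptive pooling fixes the flattened size at 2048, so other resolutions also work
print(model(torch.randn(4, 3, 64, 64)).shape)   # torch.Size([4, 10])
```

The `AdaptiveMaxPool2d` layer is what makes the `fc1` input size independent of the image size: a 64×64 input yields 62×62 feature maps after the convolution, but they are still pooled down to 8×8 before flattening.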
