
Using nn.ModuleList over Python list dramatically slows down training

I'm training a very simple model that takes the number of hidden layers as a parameter. I originally stored these hidden layers in a vanilla Python list [], but when I convert that list to an nn.ModuleList, training slows down dramatically, by at least an order of magnitude!

AdderNet

import torch.nn as nn

class AdderNet(nn.Module):
    def __init__(self, num_hidden, hidden_width):
        super(AdderNet, self).__init__()
        self.relu = nn.ReLU()

        self.hiddenLayers = []
        self.inputLayer = nn.Linear(2, hidden_width)
        self.outputLayer = nn.Linear(hidden_width, 1)

        for i in range(num_hidden):
            self.hiddenLayers.append(nn.Linear(hidden_width, hidden_width))

        self.hiddenLayers = nn.ModuleList(self.hiddenLayers)  # <--- causes DRAMATIC slowdown!

    def forward(self, x):
        out = self.inputLayer(x)
        out = self.relu(out)

        for layer in self.hiddenLayers:
            out = layer(out)
            out = self.relu(out)

        return self.outputLayer(out)

Training

for epoch in range(num_epochs):
    for i in range(len(data)):
        out = model(data[i].x)  # call the module directly rather than model.forward()
        loss = lossFunction(out, data[i].y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

That's because with a normal Python list, the hidden layers' parameters are never registered with the module, so they don't appear in model.parameters() and the optimizer never updates them. With an nn.ModuleList they are registered. In other words, in the original scenario you were never actually training the hidden layers, which is why it appeared faster. (Print out list(model.parameters()) in each case and compare!)
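To make this concrete, here is a minimal sketch (not part of the original post; the class Net, the helper num_params, and the layer sizes are made up for illustration) that builds the same architecture both ways and compares how many parameters are actually registered:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self, num_hidden, hidden_width, use_module_list):
        super().__init__()
        self.inputLayer = nn.Linear(2, hidden_width)
        self.outputLayer = nn.Linear(hidden_width, 1)
        hidden = [nn.Linear(hidden_width, hidden_width) for _ in range(num_hidden)]
        # Plain list: the hidden layers are NOT registered as submodules, so
        # model.parameters() (and hence the optimizer) never sees them.
        # nn.ModuleList: they are registered like any other submodule.
        self.hiddenLayers = nn.ModuleList(hidden) if use_module_list else hidden

def num_params(model):
    return sum(p.numel() for p in model.parameters())

print(num_params(Net(8, 256, use_module_list=False)))  # input + output layers only
print(num_params(Net(8, 256, use_module_list=True)))   # all hidden layers included

The second count is much larger, which is the extra work the optimizer was skipping before: the ModuleList version is slower because it is the one actually training the whole network.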
