
Using nn.ModuleList over Python list dramatically slows down training

I'm training a very simple model that takes the number of hidden layers as a parameter. I originally stored these hidden layers in a vanilla Python list [], but when I converted that list to an nn.ModuleList, training slowed down dramatically, by at least an order of magnitude!

AdderNet

class AdderNet(nn.Module):
    def __init__(self, num_hidden, hidden_width):
        super(AdderNet, self).__init__()
        self.relu = nn.ReLU()

        self.hiddenLayers = []
        self.inputLayer = nn.Linear(2, hidden_width)
        self.outputLayer = nn.Linear(hidden_width, 1)

        for i in range(num_hidden):
            self.hiddenLayers.append(nn.Linear(hidden_width, hidden_width))

        self.hiddenLayers = nn.ModuleList(self.hiddenLayers)  # <--- causes DRAMATIC slowdown!

    def forward(self, x):
        out = self.inputLayer(x)
        out = self.relu(out)

        for layer in self.hiddenLayers:
            out = layer(out)
            out = self.relu(out)

        return self.outputLayer(out)

Training

for epoch in range(num_epochs):
    for i in range(0,len(data)):
        out = model(data[i].x)  # call the module itself rather than .forward() so hooks run
        loss = lossFunction(out, data[i].y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

That's because when you use a normal Python list, the parameters are not added to the model's parameter list, but when you use an nn.ModuleList, they are. So in the original scenario you were never actually training the hidden layers, which is why it was faster. (Print out model.parameters() in each case and see what happens!)
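The difference is easy to verify directly. The sketch below (hypothetical module names, small layer sizes chosen for illustration) contrasts the two storage choices: layers held in a plain Python list are invisible to parameters(), so the optimizer never receives them, while nn.ModuleList registers each layer as a submodule.

```python
import torch.nn as nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: these Linear layers are NOT registered as submodules
        self.layers = [nn.Linear(4, 4) for _ in range(3)]

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList: each layer is registered, so its parameters are tracked
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(3))

print(len(list(PlainList().parameters())))   # 0 -> optimizer would get an empty parameter list
print(len(list(Registered().parameters())))  # 6 -> a weight and a bias for each of the 3 layers
```

With the plain list, optimizer.step() has nothing to update for the hidden layers, which is exactly why the original run appeared faster: most of the network was frozen.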
