繁体   English   中英

Pytorch nn.Module 的子类没有属性“参数”

[英]Pytorch Subclass of nn.Module has no attribute 'parameters'

Python 版本:Python 3.8.5
Pytorch 版本:'1.6.0'

我正在定义 LSTM,它是 nn.Module 的子类。 我正在尝试创建优化器,但出现以下错误: torch.nn.modules.module.ModuleAttributeError: 'LSTM' object has no attribute 'paramters'

我有两个代码文件,train.py 和 lstm_class.py(包含 LSTM 类)。 我将尝试制作一个最小的工作示例,如果有任何其他信息有帮助,请告诉我。



lstm_class.py 中的代码:

import torch.nn as nn

class LSTM(nn.Module):

    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, drop_prob=0.2):
        super(LSTM, self).__init__()

        # network size parameters
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim


        # the layers of the network
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.lstm = nn.LSTM(self.embedding_dim, self.hidden_dim, self.n_layers, dropout=drop_prob, batch_first=True)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(self.hidden_dim, self.vocab_size)



    def forward(self, input, hidden):
        # Defines forward pass, probably isn't relevant

    def init_hidden(self, batch_size):
        #Initializes hidden state, probably isn't relevant

train.py 中的代码

import torch
import torch.optim
import torch.nn as nn
import lstm_class

vocab_size = 1000
embedding_dim = 256
hidden_dim = 256
n_layers = 2

net = lstm_class.LSTM(vocab_size, embedding_dim, hidden_dim, n_layers)
optimizer = torch.optim.Adam(net.paramters(), lr=learning_rate) 

我在上面写的最后一行收到错误。 完整的错误信息:

Traceback (most recent call last):
  File "train.py", line 58, in <module>
    optimizer = torch.optim.Adam(net.paramters(), lr=learning_rate)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 771, in __getattr__
    raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
torch.nn.modules.module.ModuleAttributeError: 'LSTM' object has no attribute 'paramters'

任何有关如何解决此问题的提示将不胜感激。 同样如上所述,让我知道是否还有其他相关信息。 谢谢

这不是net.paramters() ,而是net.parameters() :)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM