
nn.LSTM() received an invalid combination of arguments

I'm using the LSTM from PyTorch in my code to predict a time series. I wrote this code:

import torch.nn as nn

class LSTM_model(nn.Module):
    def __init__(self, input_size, output_size, hidden_size,num_layers,dropout):
        super(LSTM_model,self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.seq_len = seq_len  # seq_len is not a constructor argument; this relies on a global seq_len
        self.num_layers = num_layers
        self.dropout = dropout
        self.output_size = output_size
        
        # self.lstm=nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers)
        self.lstm = nn.LSTM(self.input_size, self.hidden_size,self.num_layers , self.dropout, batch_first=True)
        self.fc= nn.Linear(self.hidden_size , self.output_size)
        
    def forward(self, x , hidden):
        x, hidden= self.lstm(x,hidden)
        x = self.fc(x)
        return x,hidden

But when I use the class, I get an error on the line x, hidden = self.lstm(x, hidden), coming from PyTorch's internal nn.LSTM() function:

<ipython-input-63-211c1442b5a7> in forward(self, x, hidden)
     15 
     16     def forward(self, x , hidden):
---> 17         x, hidden= self.lstm(x,hidden)
     18         x = self.fc(x)
     19         return x,hidden

D:\Anaconda\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

D:\Anaconda\lib\site-packages\torch\nn\modules\rnn.py in forward(self, input, hx)
    232         _impl = _rnn_impls[self.mode]
    233         if batch_sizes is None:
--> 234             result = _impl(input, hx, self._flat_weights, self.bias, self.num_layers,
    235                            self.dropout, self.training, self.bidirectional, self.batch_first)
    236         else:

TypeError: rnn_tanh() received an invalid combination of arguments - got (Tensor, Tensor, list, int, int, float, bool, bool, bool), but expected one of:
 * (Tensor data, Tensor batch_sizes, Tensor hx, tuple of Tensors params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional)
      didn't match because some of the arguments have invalid types: (Tensor, Tensor, !list!, !int!, !int!, !float!, !bool!, bool, bool)
 * (Tensor input, Tensor hx, tuple of Tensors params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first)
      didn't match because some of the arguments have invalid types: (Tensor, Tensor, !list!, !int!, int, float, bool, bool, bool)
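The !int! flagged under "bool has_biases" is worth a closer look. Python's bool is a subclass of int, but an int such as 0 is not a bool, so a parameter that only accepts genuine booleans rejects it. Below is a tiny pure-Python illustration. strict_bool_slot is a hypothetical stand-in for that kind of check and is not part of PyTorch.

```python
# Hypothetical helper (not a PyTorch API): mimics a parameter slot that
# accepts only genuine bool objects, like "bool has_biases" in the error.
def strict_bool_slot(value):
    """Return True if value would be accepted by a bool-only parameter."""
    return isinstance(value, bool)

# bool is a subclass of int, so True passes an int check ...
print(isinstance(True, int))    # True
# ... but the int 0 that ended up in bias is not a bool.
print(strict_bool_slot(0))      # False
print(strict_bool_slot(False))  # True
```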

I instantiate the model with this line:

model = LSTM_model(input_size=1, output_size=1, hidden_size=128, num_layers=2, dropout=0).to(device)

And it is called here:

from tqdm.auto import tqdm

def loop_fn(mode, dataset, dataloader, model, criterion, optimizer,device):
    if mode =="train":
        model.train()
    elif mode =="test":
        model.eval()
    cost = 0
    for feature, target in tqdm(dataloader, desc=mode.title()):
        feature, target = feature.to(device), target.to(device)
        output , hidden = model(feature,None)
        loss = criterion(output,target)
        
        if mode =="train":
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        
        cost += loss.item() * feature.shape[0]
    cost = cost / len(dataset)
    return cost

Thanks in advance.

It took me a while to spot, but you are initializing your nn.LSTM incorrectly because of the positional arguments:

self.lstm = nn.LSTM(self.input_size, self.hidden_size,
    self.num_layers, self.dropout, batch_first=True)

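The above passes self.dropout to the parameter named bias. This happens because nn.LSTM's positional parameters are ordered input_size, hidden_size, num_layers, bias, batch_first, dropout. With dropout=0, bias becomes the int 0 (the !int! the error flags for has_biases), as the module's repr shows: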

>>> model.lstm
LSTM(1, 128, num_layers=2, bias=0, batch_first=True)
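To see the binding concretely without PyTorch, here is a minimal pure-Python sketch using inspect.signature. lstm_like is a hypothetical stand-in whose parameter order is assumed to match the order documented for nn.LSTM. It is not the real constructor.

```python
import inspect

# Hypothetical stand-in whose parameters follow the order documented for
# torch.nn.LSTM; it exists only so inspect can show how arguments bind.
def lstm_like(input_size, hidden_size, num_layers=1, bias=True,
              batch_first=False, dropout=0.0, bidirectional=False):
    pass

sig = inspect.signature(lstm_like)

# The question's call: the dropout value (0) is the 4th positional argument.
positional = sig.bind(1, 128, 2, 0, batch_first=True)
positional.apply_defaults()
print(positional.arguments["bias"], positional.arguments["dropout"])  # 0 0.0

# Binding by keyword: every value lands in the parameter it was meant for.
keyword = sig.bind(input_size=1, hidden_size=128, num_layers=2,
                   dropout=0, batch_first=True)
keyword.apply_defaults()
print(keyword.arguments["bias"], keyword.arguments["dropout"])  # True 0
```

Only the positional call depends on parameter order. The keyword call gives the same result no matter how the parameters are ordered.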

You probably want to use keyword arguments instead:

self.lstm = nn.LSTM(
    input_size=self.input_size, 
    hidden_size=self.hidden_size, 
    num_layers=self.num_layers, 
    dropout=self.dropout, 
    batch_first=True)

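This gives the intended configuration. The repr only lists non-default settings, so the default bias=True and dropout=0 do not appear: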

>>> model.lstm
LSTM(1, 128, num_layers=2, batch_first=True)
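Putting it together, here is a sketch of the question's class with the keyword-argument fix applied, plus a quick forward-pass check. The input sizes (a batch of 4 sequences of length 10) are made-up example values. The seq_len attribute is dropped because the constructor never receives it. Everything else follows the question.

```python
import torch
import torch.nn as nn

class LSTM_model(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, num_layers, dropout):
        super().__init__()
        # Keyword arguments make each value land in the intended parameter.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        # x has shape (batch, seq_len, input_size) because batch_first=True
        x, hidden = self.lstm(x, hidden)
        x = self.fc(x)  # applied to the output at every time step
        return x, hidden

model = LSTM_model(input_size=1, output_size=1, hidden_size=128, num_layers=2, dropout=0)
print(model.lstm)       # LSTM(1, 128, num_layers=2, batch_first=True)
print(model.lstm.bias)  # True

x = torch.randn(4, 10, 1)  # made-up batch: 4 sequences of length 10
out, (h, c) = model(x, None)
print(out.shape)  # torch.Size([4, 10, 1])
print(h.shape)    # torch.Size([2, 4, 128]); (num_layers, batch, hidden)
```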


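Statement: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.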

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM