簡體   English   中英

LSTM 總是預測相同的 class

[英]LSTM always predicts the same class

我正在嘗試使用 LSTM 解決 nlp 分類問題。 model 的代碼在這里定義:

class LSTM(nn.Module):

  def __init__(self, hidden_size, embedding_size=66 ):

      super().__init__()
      self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first = True, bidirectional = True)
      self.fc = nn.Linear(2*hidden_size,2)

  def forward(self, input_seq):
      
      output, (hidden_state, cell_state) = self.lstm(input_seq)

      hidden_state = torch.cat((hidden_state[-1,:], hidden_state[-2,:]), -1)

      logits = self.fc(hidden_state)
      
      return nn.LogSoftmax(dim=1)(logits)

我用來訓練這個 model 的 function 在這里:

def train_loop(dataloader, model, loss_fn, optimizer):
    
    loss_fn = loss_fn
    size = len(dataloader.dataset)
    model.train()
    zeros = 0
    for batch, (X, y) in enumerate(dataloader):

        # Transform string into tensor
        tensor = torch.zeros(1,len(X[0]),66)
        for i in range(len(X[0])):
            tensor[0][i][ctoi[X[0][i]]] = 1

        pred = model(tensor)

        target = torch.zeros(2, dtype=torch.long)
        target[y] = 1
        
        if batch % 100 == 0:
            print(pred.squeeze(), target)
        loss = loss_fn(pred.squeeze(), target)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if pred.squeeze().argmax() == 0:
            zeros += 1

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    print(f'In trainning predicted {zeros} zeroes out of {size} samples')

X 仍然是字符串,這就是為什么我需要在通過 model 運行之前將它們轉換為張量。 y 是 0 或 1(因為它是一個二元分類問題),我需要將其轉換為形狀為 (2,) 的張量以通過損失 function。

出於某種原因,我不斷得到為每個輸入預測的相同 class。 這些類甚至沒有那么不平衡(~45% 到 55%),我已經嘗試在損失 function 中更改類的權重,但沒有任何改進,它要么收斂於預測總是 0 要么總是 1。大部分它收斂到始終預測為 0 的時間,這更沒有意義,因為通常發生的情況是 class 0 的樣本少於 class 1 的樣本。

由於您正在訓練二進制分類 model,因此您的 output 暗淡應該為 1(對應於單個概率 P(y|x))。 這意味着您從數據加載器中檢索的y應該是損失 function 中使用的 y(假設交叉熵損失)。 因此,預測的 class 為y_hat = round(pred) (即預測值 >= 0.5)。

為了清楚起見,如果 one-hot 編碼發生在您的數據集中(在__getitem____iter__中),則遵循您的邏輯會容易得多。 還值得注意的是,您不使用嵌入,因此分類器的代碼有點誤導。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM