簡體   English   中英

用於多類分類任務的多層感知器

[英]Multilayer Perceptron for multiclass classification task

假設我有一個 MLP,它使用 ReLU 作為激活函數和CrossEntropyLoss作為損失函數來對具有 3 個特征的樣本進行分類,這些特征屬於 10 個類別之一:我將如何實現? 目標值以 0 到 9 的數字給出。當使用CrossEntropyLoss ,目標值必須是簡單的數字,而不是一個熱向量。 但是當嘗試將 MLP 的結果轉換為單個數字時,出現索引錯誤。

MLP的標准實現:

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__() 
        self.input_size = input_size
        self.hidden_size  = hidden_size
        self.output_size = output_size
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(self.hidden_size, self.output_size)
        self.softmax = torch.nn.Softmax()
    
    def forward(self, x):
        hidden = self.fc1(x)
        relu = self.relu(hidden)
        output = self.fc2(relu)
        output = self.softmax(output)
        return output

以及給我一個錯誤的執行:

mlp_model = MLP(3, 10, 10)
criterion = torch.nn.CrossEntropyLoss()
mlp_model.train()
epoch = 20
for epoch in range(epoch):
    y_pred = mlp_model(x_train)
    y_scalar = torch.argmax(y_pred, dim=1)

    loss = criterion(y_scalar, y_train) <-------------- error

    loss.backward()
mlp_model.eval()
y_pred = mlp_model(x_test)
y_scalar = torch.argmax(y_pred, dim=1)
test_loss = criterion(y_scalar, y_test) 
print('Test loss after Training' , test_loss.item())

y_pred_list = y_pred.tolist()
y_test_list = y_test.tolist()

from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_test_list, y_pred_list)

錯誤: IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

y_scalar 和 y_train 的輸出:

tensor([1, 3, 3, 3, 1, 1, 1, 3, 3, 1, 3, 1, 1, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3,
        1, 3, 1, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 3,
        3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 1, 3, 1, 1,
        1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 3, 3, 1, 1, 1, 3, 3, 3,
        3, 1, 3, 1, 3, 3, 3, 1, 1, 1, 3, 1, 1, 3, 3, 1, 1, 1, 1, 3, 3, 1, 3, 3,
        1, 3, 1, 1, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3, 3, 1, 3, 1, 1, 3, 3, 1, 1, 1,
        1, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 1, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 3, 1,
        1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 3, 3, 1, 3, 3, 1, 3, 3,
        3, 3, 3, 3, 3, 1, 3, 1, 1, 3, 1, 3, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 3,
        3, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 3, 3,
        1, 3, 1, 3, 1, 3, 3, 3, 1, 1, 3, 1, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 3, 1,
        1, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3,
        3, 1, 3, 3, 3, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 3,
        3, 3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 1, 3, 1, 1, 3, 3, 3, 1, 3,
        1, 1, 1, 3, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 1, 1, 1, 3, 3, 3, 3, 1,
        3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 1, 1, 3, 1,
        3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 1, 3, 3, 1,
        1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 1, 1, 3, 3, 1,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 1, 3, 3, 3, 3, 3, 1, 3, 1,
        3, 1, 3, 1, 1, 3, 3, 1, 3, 3, 1, 3, 1, 3, 1, 3, 3, 3, 3, 3, 3, 1, 1, 3,
        1, 3, 3, 1, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 1, 3, 3,
        1, 3, 3, 3, 3, 1, 3, 1, 1, 1, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1,
        3, 1, 3, 3, 1, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3,
        1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 1, 3, 3, 3, 3,
        3, 3, 3, 3, 1, 3, 3, 1, 1, 3, 3, 3, 1, 1, 1, 3, 3, 3, 1, 1, 1, 3, 3, 1,
        3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 1, 3, 3, 1, 1,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3])
tensor([3., 4., 4., 0., 3., 2., 0., 3., 3., 2., 0., 0., 4., 3., 3., 3., 2., 3.,
        1., 3., 5., 3., 4., 6., 3., 3., 6., 3., 2., 4., 3., 6., 0., 4., 2., 0.,
        1., 5., 4., 4., 3., 6., 6., 4., 3., 3., 2., 5., 3., 4., 5., 3., 0., 2.,
        1., 4., 6., 3., 2., 2., 0., 0., 0., 4., 2., 0., 4., 5., 2., 6., 5., 2.,
        2., 2., 0., 4., 5., 6., 4., 0., 0., 0., 4., 2., 4., 1., 4., 6., 0., 4.,
        2., 4., 6., 6., 0., 0., 6., 5., 0., 6., 0., 2., 1., 1., 1., 2., 6., 5.,
        6., 1., 2., 2., 1., 5., 5., 5., 6., 5., 6., 5., 5., 1., 6., 6., 1., 5.,
        1., 6., 5., 5., 5., 1., 5., 1., 1., 1., 1., 1., 1., 1., 4., 3., 0., 3.,
        6., 6., 0., 3., 4., 0., 3., 4., 4., 1., 2., 2., 2., 3., 3., 3., 3., 0.,
        4., 5., 0., 3., 4., 3., 3., 3., 2., 3., 3., 2., 2., 6., 1., 4., 3., 3.,
        3., 6., 3., 3., 3., 3., 0., 4., 2., 2., 6., 5., 3., 5., 4., 0., 4., 3.,
        4., 4., 3., 3., 2., 4., 0., 3., 2., 3., 3., 4., 4., 0., 3., 6., 0., 3.,
        3., 4., 3., 3., 5., 2., 3., 2., 4., 1., 3., 2., 2., 3., 3., 3., 3., 5.,
        1., 3., 1., 3., 5., 0., 3., 5., 0., 4., 2., 4., 2., 4., 4., 5., 4., 3.,
        5., 3., 3., 4., 3., 0., 4., 5., 0., 3., 6., 2., 5., 5., 5., 3., 2., 3.,
        0., 4., 5., 3., 0., 4., 0., 3., 3., 0., 0., 3., 5., 4., 4., 3., 4., 3.,
        3., 2., 2., 3., 0., 3., 1., 3., 2., 3., 3., 4., 5., 2., 1., 1., 0., 0.,
        1., 6., 1., 3., 3., 3., 2., 3., 3., 0., 3., 4., 1., 3., 4., 3., 2., 0.,
        0., 4., 2., 3., 2., 1., 4., 6., 3., 2., 0., 3., 3., 2., 3., 4., 4., 2.,
        1., 3., 5., 3., 2., 0., 4., 5., 1., 3., 3., 2., 0., 2., 4., 2., 2., 2.,
        5., 4., 4., 2., 2., 0., 3., 2., 4., 4., 5., 5., 1., 0., 3., 4., 5., 3.,
        4., 5., 3., 4., 3., 3., 1., 4., 3., 3., 5., 2., 3., 2., 5., 5., 4., 3.,
        3., 3., 3., 1., 5., 3., 3., 2., 6., 0., 1., 3., 0., 1., 5., 3., 6., 3.,
        6., 0., 3., 3., 3., 5., 4., 3., 4., 0., 5., 2., 1., 2., 4., 4., 4., 4.,
        3., 3., 0., 4., 3., 0., 5., 2., 0., 5., 4., 4., 4., 3., 0., 6., 5., 2.,
        4., 5., 1., 3., 5., 3., 0., 3., 5., 1., 1., 0., 3., 4., 2., 6., 2., 0.,
        5., 3., 4., 6., 5., 3., 5., 0., 1., 3., 0., 5., 2., 2., 3., 5., 1., 0.,
        3., 1., 4., 2., 5., 6., 4., 2., 2., 6., 0., 0., 4., 6., 3., 2., 0., 3.,
        6., 1., 6., 3., 1., 3., 3., 3., 3., 2., 5., 4., 5., 5., 3., 1., 3., 3.,
        4., 4., 2., 0., 2., 0., 5., 4., 0., 0., 3., 2., 2., 2., 2., 6., 4., 6.,
        5., 5., 1., 0., 0., 4., 3., 3., 1., 3., 6., 6., 2., 3., 3., 3., 1., 2.,
        2., 5., 4., 3., 2., 1., 2., 2., 3., 2., 3., 2., 3., 3., 0., 5., 3., 3.,
        3., 4., 5., 3., 2., 1., 4., 4., 4., 4., 0., 5., 4., 1., 3., 0., 3., 4.,
        6., 3., 6., 3., 3., 3., 6., 3., 4., 3., 6., 3., 0., 3., 1., 2., 5., 6.,
        5., 2., 0., 2., 2., 3., 3., 0., 3., 5., 3., 4., 0., 3., 2., 4., 5., 2.,
        3., 2., 2., 3., 5., 2., 0., 3., 4., 3.])```

正如評論中提到的那樣,模型內部不需要 softmax,因為nn.CrossEntropyLoss包括它。 此外,損失的計算是在 argmax 之前完成的。 還要注意模型的輸入和輸出的形狀。 請參考以下更新。

import torch
class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__() 
        self.input_size = input_size
        self.hidden_size  = hidden_size
        self.output_size = output_size
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(self.hidden_size, self.output_size)
        #self.softmax = torch.nn.Softmax()
    
    def forward(self, x):
        hidden = self.fc1(x)
        relu = self.relu(hidden)
        output = self.fc2(relu)
        #output = self.softmax(output)
        return output

mlp_model = MLP(3, 10, 10)
criterion = torch.nn.CrossEntropyLoss()
mlp_model.train()
epoch = 20
x_train = torch.randn(100, 3) # random 100 inputs of shape (100, 3)
y_train = torch.randint(low=0, high=10, size=(100,)) # random 100 ground truths of shape (100,)
for epoch in range(epoch):
    y_pred = mlp_model(x_train)
    y_scalar = torch.argmax(y_pred, dim=1)

    #loss = criterion(y_scalar, y_train)# <-------------- error
    loss = criterion(y_pred, y_train) # loss calculated before argmax

    loss.backward().....

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM