[英]Multilayer Perceptron for multiclass classification task
假設我有一個 MLP,它使用 ReLU 作為激活函數和CrossEntropyLoss
作為損失函數來對具有 3 個特征的樣本進行分類,這些特征屬於 10 個類別之一:我將如何實現? 目標值以 0 到 9 的數字給出。當使用CrossEntropyLoss
,目標值必須是簡單的數字,而不是一個熱向量。 但是當嘗試將 MLP 的結果轉換為單個數字時,出現索引錯誤。
MLP的標准實現:
class MLP(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MLP, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(self.hidden_size, self.output_size)
self.softmax = torch.nn.Softmax()
def forward(self, x):
hidden = self.fc1(x)
relu = self.relu(hidden)
output = self.fc2(relu)
output = self.softmax(output)
return output
以及給我一個錯誤的執行:
mlp_model = MLP(3, 10, 10)
criterion = torch.nn.CrossEntropyLoss()
mlp_model.train()
epoch = 20
for epoch in range(epoch):
y_pred = mlp_model(x_train)
y_scalar = torch.argmax(y_pred, dim=1)
loss = criterion(y_scalar, y_train) <-------------- error
loss.backward()
mlp_model.eval()
y_pred = mlp_model(x_test)
y_scalar = torch.argmax(y_pred, dim=1)
test_loss = criterion(y_scalar, y_test)
print('Test loss after Training' , test_loss.item())
y_pred_list = y_pred.tolist()
y_test_list = y_test.tolist()
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_test_list, y_pred_list)
錯誤: IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
y_scalar 和 y_train 的輸出:
tensor([1, 3, 3, 3, 1, 1, 1, 3, 3, 1, 3, 1, 1, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3,
1, 3, 1, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 3,
3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 1, 3, 1, 1,
1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 3, 3, 1, 1, 1, 3, 3, 3,
3, 1, 3, 1, 3, 3, 3, 1, 1, 1, 3, 1, 1, 3, 3, 1, 1, 1, 1, 3, 3, 1, 3, 3,
1, 3, 1, 1, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3, 3, 1, 3, 1, 1, 3, 3, 1, 1, 1,
1, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 1, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 3, 1,
1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 3, 3, 1, 3, 3, 1, 3, 3,
3, 3, 3, 3, 3, 1, 3, 1, 1, 3, 1, 3, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 3,
3, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 3, 3,
1, 3, 1, 3, 1, 3, 3, 3, 1, 1, 3, 1, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 3, 1,
1, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3,
3, 1, 3, 3, 3, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 3,
3, 3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 1, 3, 1, 1, 3, 3, 3, 1, 3,
1, 1, 1, 3, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 1, 1, 1, 3, 3, 3, 3, 1,
3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 1, 1, 3, 1,
3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 1, 3, 3, 1,
1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 1, 1, 3, 3, 1,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 1, 3, 3, 3, 3, 3, 1, 3, 1,
3, 1, 3, 1, 1, 3, 3, 1, 3, 3, 1, 3, 1, 3, 1, 3, 3, 3, 3, 3, 3, 1, 1, 3,
1, 3, 3, 1, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 1, 3, 3,
1, 3, 3, 3, 3, 1, 3, 1, 1, 1, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1,
3, 1, 3, 3, 1, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3,
1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 1, 3, 3, 3, 3,
3, 3, 3, 3, 1, 3, 3, 1, 1, 3, 3, 3, 1, 1, 1, 3, 3, 3, 1, 1, 1, 3, 3, 1,
3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 1, 3, 3, 1, 1,
3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3])
tensor([3., 4., 4., 0., 3., 2., 0., 3., 3., 2., 0., 0., 4., 3., 3., 3., 2., 3.,
1., 3., 5., 3., 4., 6., 3., 3., 6., 3., 2., 4., 3., 6., 0., 4., 2., 0.,
1., 5., 4., 4., 3., 6., 6., 4., 3., 3., 2., 5., 3., 4., 5., 3., 0., 2.,
1., 4., 6., 3., 2., 2., 0., 0., 0., 4., 2., 0., 4., 5., 2., 6., 5., 2.,
2., 2., 0., 4., 5., 6., 4., 0., 0., 0., 4., 2., 4., 1., 4., 6., 0., 4.,
2., 4., 6., 6., 0., 0., 6., 5., 0., 6., 0., 2., 1., 1., 1., 2., 6., 5.,
6., 1., 2., 2., 1., 5., 5., 5., 6., 5., 6., 5., 5., 1., 6., 6., 1., 5.,
1., 6., 5., 5., 5., 1., 5., 1., 1., 1., 1., 1., 1., 1., 4., 3., 0., 3.,
6., 6., 0., 3., 4., 0., 3., 4., 4., 1., 2., 2., 2., 3., 3., 3., 3., 0.,
4., 5., 0., 3., 4., 3., 3., 3., 2., 3., 3., 2., 2., 6., 1., 4., 3., 3.,
3., 6., 3., 3., 3., 3., 0., 4., 2., 2., 6., 5., 3., 5., 4., 0., 4., 3.,
4., 4., 3., 3., 2., 4., 0., 3., 2., 3., 3., 4., 4., 0., 3., 6., 0., 3.,
3., 4., 3., 3., 5., 2., 3., 2., 4., 1., 3., 2., 2., 3., 3., 3., 3., 5.,
1., 3., 1., 3., 5., 0., 3., 5., 0., 4., 2., 4., 2., 4., 4., 5., 4., 3.,
5., 3., 3., 4., 3., 0., 4., 5., 0., 3., 6., 2., 5., 5., 5., 3., 2., 3.,
0., 4., 5., 3., 0., 4., 0., 3., 3., 0., 0., 3., 5., 4., 4., 3., 4., 3.,
3., 2., 2., 3., 0., 3., 1., 3., 2., 3., 3., 4., 5., 2., 1., 1., 0., 0.,
1., 6., 1., 3., 3., 3., 2., 3., 3., 0., 3., 4., 1., 3., 4., 3., 2., 0.,
0., 4., 2., 3., 2., 1., 4., 6., 3., 2., 0., 3., 3., 2., 3., 4., 4., 2.,
1., 3., 5., 3., 2., 0., 4., 5., 1., 3., 3., 2., 0., 2., 4., 2., 2., 2.,
5., 4., 4., 2., 2., 0., 3., 2., 4., 4., 5., 5., 1., 0., 3., 4., 5., 3.,
4., 5., 3., 4., 3., 3., 1., 4., 3., 3., 5., 2., 3., 2., 5., 5., 4., 3.,
3., 3., 3., 1., 5., 3., 3., 2., 6., 0., 1., 3., 0., 1., 5., 3., 6., 3.,
6., 0., 3., 3., 3., 5., 4., 3., 4., 0., 5., 2., 1., 2., 4., 4., 4., 4.,
3., 3., 0., 4., 3., 0., 5., 2., 0., 5., 4., 4., 4., 3., 0., 6., 5., 2.,
4., 5., 1., 3., 5., 3., 0., 3., 5., 1., 1., 0., 3., 4., 2., 6., 2., 0.,
5., 3., 4., 6., 5., 3., 5., 0., 1., 3., 0., 5., 2., 2., 3., 5., 1., 0.,
3., 1., 4., 2., 5., 6., 4., 2., 2., 6., 0., 0., 4., 6., 3., 2., 0., 3.,
6., 1., 6., 3., 1., 3., 3., 3., 3., 2., 5., 4., 5., 5., 3., 1., 3., 3.,
4., 4., 2., 0., 2., 0., 5., 4., 0., 0., 3., 2., 2., 2., 2., 6., 4., 6.,
5., 5., 1., 0., 0., 4., 3., 3., 1., 3., 6., 6., 2., 3., 3., 3., 1., 2.,
2., 5., 4., 3., 2., 1., 2., 2., 3., 2., 3., 2., 3., 3., 0., 5., 3., 3.,
3., 4., 5., 3., 2., 1., 4., 4., 4., 4., 0., 5., 4., 1., 3., 0., 3., 4.,
6., 3., 6., 3., 3., 3., 6., 3., 4., 3., 6., 3., 0., 3., 1., 2., 5., 6.,
5., 2., 0., 2., 2., 3., 3., 0., 3., 5., 3., 4., 0., 3., 2., 4., 5., 2.,
3., 2., 2., 3., 5., 2., 0., 3., 4., 3.])```
正如評論中提到的那樣,模型內部不需要 softmax,因為nn.CrossEntropyLoss
包括它。 此外,損失的計算是在 argmax 之前完成的。 還要注意模型的輸入和輸出的形狀。 請參考以下更新。
import torch
class MLP(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MLP, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(self.hidden_size, self.output_size)
#self.softmax = torch.nn.Softmax()
def forward(self, x):
hidden = self.fc1(x)
relu = self.relu(hidden)
output = self.fc2(relu)
#output = self.softmax(output)
return output
mlp_model = MLP(3, 10, 10)
criterion = torch.nn.CrossEntropyLoss()
mlp_model.train()
epoch = 20
x_train = torch.randn(100, 3) # random 100 inputs of shape (100, 3)
y_train = torch.randint(low=0, high=10, size=(100,)) # random 100 ground truths of shape (100,)
for epoch in range(epoch):
y_pred = mlp_model(x_train)
y_scalar = torch.argmax(y_pred, dim=1)
#loss = criterion(y_scalar, y_train)# <-------------- error
loss = criterion(y_pred, y_train) # loss calculated before argmax
loss.backward().....
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.