[英]Why the pytorch classification model not learning?
我創建了一個簡單的 pytorch 分類 model ,其中包含使用 sklearns make_classification
生成的示例數據集。 即使經過數千個 epoch 的訓練,model 的准確率也徘徊在 30% 到 40% 之間。 在訓練期間,損失值波動很大。 我想知道為什么這個 model 不學習,是否由於代碼中的一些邏輯錯誤。
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
X,y = make_classification(n_features=15,n_classes=5,n_informative=4)
DEVICE = torch.device('cuda')
epochs = 5000
class CustomDataset(Dataset):
def __init__(self,X,y):
self.X = torch.from_numpy(X)
self.y = torch.from_numpy(y)
def __len__(self):
return len(self.X)
def __getitem__(self, index):
X = self.X[index]
y = self.y[index]
return (X,y)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.l1 = nn.Linear(15,10)
self.l2 = nn.Linear(10,5)
self.relu = nn.ReLU()
def forward(self,x):
x = self.l1(x)
x = self.relu(x)
x = self.l2(x)
x = self.relu(x)
return x
model = Model().double().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_function = nn.CrossEntropyLoss()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
train_data = CustomDataset(X_train,y_train)
test_data = CustomDataset(X_test,y_test)
trainloader = DataLoader(train_data, batch_size=32, shuffle=True)
testloader = DataLoader(test_data, batch_size=32, shuffle=True)
for i in range(epochs):
for (x,y) in trainloader:
x = x.to(DEVICE)
y = y.to(DEVICE)
optimizer.zero_grad()
output = model(x)
loss = loss_function(output,y)
loss.backward()
optimizer.step()
if i%200==0:
print("epoch: ",i," Loss: ",loss.item())
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
for x, y in testloader:
# calculate outputs by running x through the network
outputs = model(x.to(DEVICE)).to(DEVICE)
# the class with the highest energy is what we choose as prediction
_, predicted = torch.max(outputs.data, 1)
total += y.size(0)
correct += (predicted == y.to(DEVICE)).sum().item()
print(f'Accuracy of the network on the test data: {100 * correct // total} %')
您不應該在 output 層上使用 ReLU 激活。 通常 softmax 激活用於最后一層的多 class 分類,或者直接將 logits 饋送到損失 function 而不顯式添加 softmax 激活層。
嘗試從最后一層移除 ReLU 激活。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.