[英]Loss not Converging for CNN Model
圖像轉換和批處理
transform = transforms.Compose([
transforms.Resize((100,100)),
transforms.ToTensor(),
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
data_set = datasets.ImageFolder(root="/content/drive/My Drive/models/pokemon/dataset",transform=transform)
train_loader = DataLoader(data_set,batch_size=10,shuffle=True,num_workers=6)
下面是我的 Model
class pokimonClassifier(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3,6,3,1)
self.conv2 = nn.Conv2d(6,18,3,1)
self.fc1 = nn.Linear(23*23*18,520)
self.fc2 = nn.Linear(520,400)
self.fc3 = nn.Linear(400,320)
self.fc4 = nn.Linear(320,149)
def forward(self,x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x,2,2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x,2,2)
x = x.view(-1,23*23*18)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = F.log_softmax(self.fc4(x), dim=1)
return x
創建 model 的實例,使用 GPU,設置標准和優化器 這是第一個設置lr = 0.001
然后后來更改為0.0001
model = pokimonClassifier()
model.to('cuda')
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr = 0.0001)
訓練數據集
for e in range(epochs):
train_crt = 0
for b,(train_x,train_y) in enumerate(train_loader):
b+=1
train_x, train_y = train_x.to('cuda'), train_y.to('cuda')
# train model
y_preds = model(train_x)
loss = criterion(y_preds,train_y)
# analysis model
predicted = torch.max(y_preds,1)[1]
correct = (predicted == train_y).sum()
train_crt += correct
# print loss and accuracy
if b%50 == 0:
print(f'Epoch {e} batch{b} loss:{loss.item()} ')
# updating weights and bais
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss.append(loss)
train_correct.append(train_crt)
我的損失值保持在 4 - 3 之間,並且沒有收斂到 0。我對深度學習非常陌生,對此我了解不多。
我使用的數據集在這里: https://www.kaggle.com/thedagger/pokemon-generation-one
非常感謝您的幫助。 謝謝你
您的網絡的問題是您應用了兩次softmax()
- 一次在fc4()
層,一次在使用nn.CrossEntropyLoss()
時。
根據官方文檔, Pytorch 在應用nn.CrossEntropyLoss()
) 時會處理softmax()
) 。
所以在你的代碼中,請改變這一行
x = F.log_softmax(self.fc4(x), dim=1)
至
x = self.fc4(x)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.