
MNIST handwritten digit recognition: loss does not change

I am using PyTorch to implement MNIST handwritten digit recognition, but my loss has not changed since the start of training. Why? Here is my code:

import gzip
import pickle
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F

f = gzip.open("./mnist.pkl.gz","rb")
train_data, val_data, test_data = pickle.load(f,encoding='latin1')
f.close()

train_data_img = torch.tensor(train_data[0].reshape(250,200,784))
train_data_ans = torch.tensor(train_data[1].reshape(250,200))

class Net(nn.Module):

    def __init__(self):
        super(Net,self).__init__()
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 10)

    def forward(self,x):
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
optimizer = optim.SGD(net.parameters(),lr=0.105,momentum=0.98)

losses = []
loss = 1
for epoch in range(5):
    for data,ans in zip(train_data_img,train_data_ans):
        out = net(data)
        ans = F.one_hot(ans)
        loss = F.mse_loss(out, ans)
        optimizer.zero_grad()
        loss.backword()
        optimizer.step()
    losses.append(loss.item())

print(losses)
xlabel = np.linspace(0,len(losses),len(losses))
plt.plot(xlabel,losses)
plt.show()

I find that every input produces the same output. What is the problem? I have only just started learning this.

For classification, you should use an activation on the last layer, such as softmax:

def forward(self,x):
    x = F.leaky_relu(self.fc1(x))
    x = F.leaky_relu(self.fc2(x))
    x = self.fc3(x)
    return F.softmax(x, dim=1)
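
To see what dim=1 does here, the short sketch below applies softmax to a made-up batch of logits (the tensor values and the batch size of 3 are assumptions for illustration, not from the question). It assumes a 2-D input of shape (batch, classes), as in the question's batches of 200 images; a single 1-D vector has no dimension 1, so dim=-1 would be the safer choice if unbatched inputs are possible.

```python
import torch
from torch.nn import functional as F

# Hypothetical logits for a batch of 3 samples and 10 classes.
logits = torch.randn(3, 10)

# dim=1 normalizes across the class axis, independently for each sample.
probs = F.softmax(logits, dim=1)

print(probs.shape)       # same shape as the input
print(probs.sum(dim=1))  # each row sums to approximately 1
```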

Also try more epochs, lower the learning rate to something like 0.001, and fix the loss.backword() typo (it should be loss.backward()).
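
Put together, the suggestions above might look like the following sketch. It is not a tuned solution: the random tensors stand in for the real MNIST file (an assumption so that the example runs on its own), the epoch count of 10 is arbitrary, and the momentum is simply kept from the question. Two further details are assumptions beyond the answer's text: F.one_hot is given num_classes=10 so the target always has 10 columns, and the target is converted to float so that its dtype matches the network output passed to F.mse_loss.

```python
import torch
from torch import nn, optim
from torch.nn import functional as F

torch.manual_seed(0)

# Stand-in data (assumption: random values, not real MNIST), using the
# question's layout: batches of 200 flattened 28x28 images with integer labels.
images = torch.rand(5, 200, 784)
labels = torch.randint(0, 10, (5, 200))

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 10)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        # Softmax over the class axis, as suggested in the answer.
        return F.softmax(self.fc3(x), dim=1)

net = Net()
# Lower learning rate from the answer; momentum unchanged from the question.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.98)

losses = []
for epoch in range(10):  # "more epochs"; 10 is an arbitrary choice
    for data, ans in zip(images, labels):
        out = net(data)
        # Fixed-width one-hot target, cast to float to match the output dtype.
        target = F.one_hot(ans, num_classes=10).float()
        loss = F.mse_loss(out, target)
        optimizer.zero_grad()
        loss.backward()  # corrected spelling of backward()
        optimizer.step()
    losses.append(loss.item())  # loss of the last batch in each epoch

print(losses)
```

With random labels there is nothing meaningful to learn, so this only checks that the loop runs end to end; on the real data, swap in the tensors built from mnist.pkl.gz and watch whether the recorded losses start to move.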
