繁体   English   中英

使用MNIST数据集Pytorch训练SqueezeNet模型

[英]Train SqueezeNet model using MNIST dataset Pytorch

我想使用MNIST数据集而不是ImageNet数据集训练SqueezeNet 1.1模型。
我可以使用与torchvision.models.squeezenet相同的模型吗?
谢谢!

TorchVision仅为SqueezeNet体系结构提供ImageNet数据预训练模型。 但是,您可以使用MNIST数据集通过只取模型(而不是预先训练的一个),训练自己的模型torchvision.models

In [10]: import torchvision as tv

# get the model architecture only; ignore `pretrained` flag
In [11]: squeezenet11 = tv.models.squeezenet1_1()

In [12]: squeezenet11.training   
Out[12]: True

现在,您可以使用此体系结构在MNIST数据上训练模型,这应该不会花费太长时间。


要记住的一种修改是更新MNIST的类数为10。 具体来说,应将1000更改为10,并相应地调整内核和步幅。

  (classifier): Sequential(
    (0): Dropout(p=0.5)
    (1): Conv2d(512, 1000, kernel_size=(1, 1), stride=(1, 1))
    (2): ReLU(inplace)
    (3): AvgPool2d(kernel_size=13, stride=1, padding=0)
  )

这是相关的解释: finetuning_torchvision_models-squeezenet

可以对预训练的权重进行初始化,但是由于MNIST图像为28X28像素,因此您会在步幅和内核大小上遇到麻烦。 减少的可能性最大可能是在网络处于其推断层之前导致(batch_sizex1x1xchannel)特征图,然后将导致错误。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM