[英]Train SqueezeNet model using MNIST dataset Pytorch
我想使用MNIST数据集而不是ImageNet数据集训练SqueezeNet 1.1模型。
我可以使用与torchvision.models.squeezenet相同的模型吗?
谢谢!
TorchVision仅为SqueezeNet体系结构提供ImageNet数据预训练模型。 但是,您可以使用MNIST数据集通过只取模型(而不是预先训练的一个),训练自己的模型torchvision.models
。
In [10]: import torchvision as tv
# get the model architecture only; ignore `pretrained` flag
In [11]: squeezenet11 = tv.models.squeezenet1_1()
In [12]: squeezenet11.training
Out[12]: True
现在,您可以使用此体系结构在MNIST数据上训练模型,这应该不会花费太长时间。
要记住的一种修改是更新MNIST的类数为10。 具体来说,应将1000更改为10,并相应地调整内核和步幅。
(classifier): Sequential(
(0): Dropout(p=0.5)
(1): Conv2d(512, 1000, kernel_size=(1, 1), stride=(1, 1))
(2): ReLU(inplace)
(3): AvgPool2d(kernel_size=13, stride=1, padding=0)
)
可以对预训练的权重进行初始化,但是由于MNIST图像为28X28像素,因此您会在步幅和内核大小上遇到麻烦。 减少的可能性最大可能是在网络处于其推断层之前导致(batch_sizex1x1xchannel)特征图,然后将导致错误。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.