
How to improve cat vs. dog classification with a CNN in PyTorch

I tried to follow the CNN architecture from the paper ImageNet Classification with Deep Convolutional Neural Networks (https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf). The paper classifies 1000 classes, whereas I am only trying to classify 2.
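For reference, here is a minimal PyTorch sketch of an AlexNet-style network with a 2-class output. It is an illustration under stated assumptions, not the code in the linked notebook: it uses the paper's filter counts (96, 256, 384, 384, 256) and fully connected sizes (4096, 4096), replaces the 1000-way output layer with a 2-way one, omits local response normalization and the two-GPU split, and borrows the padding and adaptive pooling layout of torchvision's AlexNet. The class name `AlexNetLike` is hypothetical.

```python
import torch
import torch.nn as nn


class AlexNetLike(nn.Module):
    """Hypothetical AlexNet-style CNN: 5 conv layers + 3 fully connected layers."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),  # 224 -> 55
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 55 -> 27 (overlapping pooling)
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 27 -> 13
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 13 -> 6
        )
        # Keeps the classifier input at 256 x 6 x 6 if the input size differs slightly.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),  # 2 outputs instead of the paper's 1000
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        # Raw logits: nn.CrossEntropyLoss applies log-softmax itself,
        # so no softmax layer is added here.
        return self.classifier(x)


model = AlexNetLike(num_classes=2)
dummy = torch.randn(4, 3, 224, 224)  # a fake batch of 4 RGB 224x224 images
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)  # torch.Size([4, 2])
```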

However, my test accuracy is stuck at 50% (chance level for two classes), and the model does not seem to be learning.
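A frequent reason for exactly 50% on a two-class problem is that the model predicts the same class for every image. The plain-Python check below assumes a roughly balanced test set; the helper name `diagnose_predictions` and the label lists are hypothetical.

```python
from collections import Counter


def diagnose_predictions(preds, labels):
    """Return (accuracy, share of the most frequently predicted class)."""
    if not preds or len(preds) != len(labels):
        raise ValueError("preds and labels must be non-empty and the same length")
    accuracy = sum(p == y for p, y in zip(preds, labels)) / len(labels)
    top_count = Counter(preds).most_common(1)[0][1]
    return accuracy, top_count / len(preds)


# Hypothetical balanced test labels (0 = cat, 1 = dog) and a "collapsed" model
# that predicts cat for every image.
labels = [0] * 1250 + [1] * 1250
preds = [0] * 2500

acc, dominant_share = diagnose_predictions(preds, labels)
print(f"accuracy={acc:.2f}, most common prediction share={dominant_share:.2f}")
# accuracy=0.50, most common prediction share=1.00
```

If the dominant share is close to 1.00, the network has collapsed to a single class rather than learning useful features. With a PyTorch model, `preds` would typically come from `outputs.argmax(dim=1).tolist()`.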

I am training on 23K images of cats and dogs and testing on 2,500 images.

Here is the URL of my notebook: https://github.com/jinglescode/workspace/blob/master/my-journey-computer-vision/codes/Cats_and_Dogs.ipynb

Could anyone advise what is wrong or what I have missed? I am willing to learn.

Normalize your data!!

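For image data, you can apply the commonly recommended transforms to your train and test datasets. The mean and std values are the per-channel ImageNet statistics that torchvision's pretrained models expect: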

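from torchvision import transforms

# Per-channel ImageNet mean/std, applied after ToTensor() scales pixels to [0, 1]
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training: random augmentation, a 224x224 crop, tensor conversion, normalization
train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       normalize])

# Testing: deterministic resize and center crop, no augmentation
test_transforms = transforms.Compose([transforms.Resize(255),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      normalize])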
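`transforms.Normalize` computes output[c] = (input[c] - mean[c]) / std[c] for each channel c, after `ToTensor()` has scaled 8-bit pixels to [0, 1]. The plain-Python sketch below reproduces both steps for a single RGB pixel to show that inputs end up roughly centered on zero (about -2.1 to 2.6) instead of spanning [0, 1]. The helper `normalize_pixel` is hypothetical and only for illustration; in practice the transforms above do this on whole tensors.

```python
# ImageNet per-channel statistics, the same values passed to transforms.Normalize.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb_uint8):
    """Mimic ToTensor() + Normalize() for one 8-bit RGB pixel (hypothetical helper)."""
    scaled = [v / 255.0 for v in rgb_uint8]  # ToTensor step: [0, 255] -> [0.0, 1.0]
    # Normalize step: subtract the channel mean, divide by the channel std
    return tuple((v - m) / s for v, m, s in zip(scaled, IMAGENET_MEAN, IMAGENET_STD))


black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print([round(v, 3) for v in black])  # [-2.118, -2.036, -1.804]
print([round(v, 3) for v in white])  # [2.249, 2.429, 2.64]
```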

Additional comments:

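  1. If you want to "peek" into a dataloader, calling next(iter(dataloader)) is not a good idea, since it builds a new iterator (starting worker processes when num_workers > 0) just to fetch one whole batch. Instead, access the dataset stored inside the dataloader and index it, which calls its __getitem__: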

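     image, label = dataloader.dataset[0]  # one (image, label) sample, not a batch
  2. If your training is "stuck", the usual first step is to change the learning rate, for example by trying values an order of magnitude smaller or larger.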
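One way to act on the learning-rate advice is a quick sanity sweep: train a fresh copy of the model on a single fixed batch for a few steps per learning rate and see which rates drive the loss down. The sketch below assumes plain SGD; the function `quick_lr_sweep`, the tiny stand-in model, the random batch, and the candidate rates are all hypothetical. If no rate can reduce the loss even on one batch, that hints at a problem in the model, labels, or loss setup rather than the learning rate alone.

```python
import torch
import torch.nn as nn


def quick_lr_sweep(make_model, images, labels, lrs, steps=20, seed=0):
    """For each learning rate, train a freshly initialized model on ONE fixed
    batch and return {lr: (loss at first step, loss at last step)}."""
    criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
    results = {}
    for lr in lrs:
        torch.manual_seed(seed)  # identical initial weights for every learning rate
        model = make_model()
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        losses = []
        for _ in range(steps):
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        results[lr] = (losses[0], losses[-1])
    return results


def make_model():
    # Hypothetical tiny stand-in; pass your own network constructor instead.
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 2))


torch.manual_seed(0)
# Hypothetical stand-in batch; in practice, use one real batch from the train loader.
images = torch.randn(16, 3, 32, 32)
labels = torch.randint(0, 2, (16,))

results = quick_lr_sweep(make_model, images, labels, lrs=[1e-1, 1e-2, 1e-3])
for lr, (first, last) in results.items():
    print(f"lr={lr:g}: loss {first:.3f} -> {last:.3f}")
```

For a real run, pass a function that builds your own network and one real (images, labels) batch from the train loader, then continue full training with a rate that reduced the loss steadily.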
