Simple CNN for Image Classification
Introduction
Image classification is one of the most common computer vision problems: an algorithm analyzes an image and classifies the object in it. Many kinds of models have been successful at image classification tasks. The Simple CNN library was developed as a wrapper around PyTorch transfer learning with pre-trained models to make image classification easier, and it also includes ONNX inference for the trained models. In this post, we’ll walk through how to train on a custom dataset and run inference with Simple CNN.
Data Preparation
For this work, the well-known Cats vs Dogs dataset was used as the primary dataset. Download Dogs vs. Cats Redux: Kernels Edition from Kaggle and split it into train and val folders with the following structure:
Cats_vs_Dogs
├── train
│   ├── Cat
│   │   ├── cat.0.jpg
│   │   ├── cat.1.jpg
│   │   ├── cat.2.jpg
│   │   ├── .........
│   │   └── cat.500.jpg
│   │
│   └── Dog
│       ├── dog.0.jpg
│       ├── dog.1.jpg
│       ├── dog.2.jpg
│       ├── .........
│       └── dog.500.jpg
│
└── val
    ├── Cat
    │   ├── cat.501.jpg
    │   ├── cat.502.jpg
    │   ├── cat.503.jpg
    │   ├── .........
    │   └── cat.600.jpg
    │
    └── Dog
        ├── dog.501.jpg
        ├── dog.502.jpg
        ├── dog.503.jpg
        ├── .........
        └── dog.600.jpg
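The split can be scripted. The sketch below is a minimal example and is not part of Simple CNN. It assumes the Kaggle training images sit in one flat folder with names like cat.0.jpg and dog.0.jpg. The function name split_cats_vs_dogs and the default 20% validation fraction are hypothetical choices. Unlike the numbered ranges in the tree, it shuffles each class with a fixed seed, so the split is random but reproducible.

```python
import random
import shutil
from pathlib import Path


def split_cats_vs_dogs(src_dir, dst_dir, val_fraction=0.2, seed=0):
    """Copy a flat folder of cat.N.jpg / dog.N.jpg files into
    dst_dir/{train,val}/{Cat,Dog}/ and return the per-split counts.

    Hypothetical helper: assumes Kaggle-style filenames "cat.<n>.jpg"
    and "dog.<n>.jpg" in src_dir.
    """
    src, dst = Path(src_dir), Path(dst_dir)
    prefixes = {"cat": "Cat", "dog": "Dog"}  # filename prefix -> class folder
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    counts = {}
    for prefix, class_name in prefixes.items():
        files = sorted(src.glob(f"{prefix}.*.jpg"))
        rng.shuffle(files)
        n_val = int(len(files) * val_fraction)
        splits = {"val": files[:n_val], "train": files[n_val:]}
        for split, split_files in splits.items():
            out_dir = dst / split / class_name
            out_dir.mkdir(parents=True, exist_ok=True)
            for f in split_files:
                shutil.copy2(f, out_dir / f.name)  # copy, keep the originals
            counts[(split, class_name)] = len(split_files)
    return counts
```

For example, split_cats_vs_dogs("train", "Data/Images/Cats_vs_Dogs") would build the train/val layout under the data directory used in the training config, assuming the unzipped Kaggle images are in a local train folder.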
Install
Clone the repo and install the dependencies from requirements.txt in a Python environment:
git clone https://github.com/LahiRumesh/simple_cnn.git
cd simple_cnn
pip install -r requirements.txt
Training
Use config.py to change the model configuration and hyperparameters. Several pre-trained models are available in Simple CNN; here, I used resnet18 with pre-trained weights for the Cats vs Dogs classification task. You can use any other classification model available in Simple CNN instead, and the PyTorch documentation has more information about the hyperparameters.
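cfg.data_dir = 'Data/Images/Cats_vs_Dogs'  # dataset root containing train/ and val/
cfg.device = '0'  # device to train on (here device 0)
cfg.image_size = 224  # input image size
cfg.batch_size = 8
cfg.epochs = 20
cfg.model = 'resnet18'  # pre-trained backbone to fine-tune
cfg.pretrained = True  # start from pre-trained weights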
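Before launching training, it may be worth checking that cfg.data_dir really points at a folder laid out like the tree in Data Preparation. This hypothetical helper, check_dataset_layout, is not part of Simple CNN and uses only the standard library. It assumes class labels come from the subfolder names, sorted alphabetically, which is how torchvision's ImageFolder assigns them; whether Simple CNN loads data exactly that way is an assumption.

```python
from pathlib import Path

# Hypothetical set of extensions counted as images.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def check_dataset_layout(data_dir):
    """Verify data_dir has train/ and val/ with the same class subfolders.

    Returns (class_names, counts) where counts[split][class] is the number
    of image files found. Raises ValueError if the layout is inconsistent.
    """
    root = Path(data_dir)
    classes_per_split = {}
    counts = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise ValueError(f"missing folder: {split_dir}")
        # Class names are the subfolder names, sorted alphabetically
        # (the same ordering torchvision's ImageFolder uses).
        classes = sorted(d.name for d in split_dir.iterdir() if d.is_dir())
        classes_per_split[split] = classes
        counts[split] = {
            c: sum(
                1
                for f in (split_dir / c).iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
            for c in classes
        }
    if classes_per_split["train"] != classes_per_split["val"]:
        raise ValueError(f"class folders differ: {classes_per_split}")
    return classes_per_split["train"], counts
```

Calling classes, counts = check_dataset_layout(cfg.data_dir) should report ['Cat', 'Dog'] for the structure above, with per-class image counts for each split.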
Run cnn_train.py to start training. All logs are saved to wandb, so if you are not logged into your wandb account, log in (for example with wandb login) before training starts. For each epoch, the script calculates the loss and the train and validation accuracy and logs them to wandb.
python cnn_train.py
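The per-epoch bookkeeping described above can be sketched in plain Python. This is an illustration under assumptions, not Simple CNN's code: the function names and the metric keys (train_loss, train_acc, val_acc) are hypothetical, and the returned dict is simply the kind of payload a training loop could hand to wandb.log().

```python
def epoch_accuracy(predictions, labels):
    """Fraction of predicted class indices that match the true labels."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        raise ValueError("no samples in this epoch")
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


def epoch_metrics(epoch, loss_sum, n_samples,
                  train_preds, train_labels, val_preds, val_labels):
    """Build a per-epoch metrics dict (hypothetical key names).

    loss_sum is the sum of per-sample training losses over the epoch.
    """
    return {
        "epoch": epoch,
        "train_loss": loss_sum / n_samples,  # mean loss per sample
        "train_acc": epoch_accuracy(train_preds, train_labels),
        "val_acc": epoch_accuracy(val_preds, val_labels),
    }
```

In a real loop you would call wandb.init(project=...) once before training and wandb.log(epoch_metrics(...)) at the end of each epoch; keeping the metric computation separate from the logging call makes it easy to check without a wandb account.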
For each training experiment, an ONNX weight file named after the model and a class file are generated in the “models/Cats_vs_Dogs” folder.
After training, the model reaches an accuracy of 0.9650 on the validation dataset.
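To put that figure in perspective: if the val split matches the tree in Data Preparation exactly (cat.501 to cat.600 and dog.501 to dog.600, that is 100 images per class and 200 in total, which is an assumption), an accuracy of 0.9650 corresponds to 193 correct predictions:

```latex
\text{accuracy} = \frac{\text{correct}}{\text{total}} = \frac{193}{200} = 0.9650
```

In other words, about 7 of the 200 validation images would be misclassified.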
Inference
After training, you can run inference with the saved ONNX weight file using cnn_inference.py. Execute the following command:
python cnn_inference.py --model_path=models/Cats_vs_Dogs/cats_vs_dogs_resnet18_exp_1.onnx --class_path=models/Cats_vs_Dogs/classes.txt --img_path=test1.jpg --image_size=224
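If you want to call the exported model from your own code rather than through cnn_inference.py, a minimal ONNX Runtime sketch follows. It is not how cnn_inference.py is implemented. It assumes classes.txt holds one class name per line, the model takes a single float32 NCHW input and outputs raw logits, and training used the standard ImageNet mean/std normalization. Adjust the preprocessing to whatever transforms your training run actually applied, or predictions may be off. Running predict() needs the onnxruntime and Pillow packages; the other helpers need only NumPy.

```python
from pathlib import Path

import numpy as np

# Assumption: ImageNet normalization, common for torchvision pre-trained
# backbones. Replace with the transforms your training run used.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def load_classes(class_path):
    """Read class names, assuming one name per line in classes.txt."""
    lines = Path(class_path).read_text().splitlines()
    return [line.strip() for line in lines if line.strip()]


def preprocess(img_path, image_size=224):
    """Load an image as a contiguous (1, 3, H, W) float32 batch."""
    from PIL import Image  # lazy import: only needed for real images

    img = Image.open(img_path).convert("RGB").resize((image_size, image_size))
    x = np.asarray(img, dtype=np.float32) / 255.0  # HWC, values in [0, 1]
    x = (x - MEAN) / STD                            # per-channel normalize
    x = x.transpose(2, 0, 1)[np.newaxis]            # HWC -> NCHW with batch dim
    return np.ascontiguousarray(x, dtype=np.float32)


def postprocess(logits, classes):
    """Softmax over the class scores; return (label, probability)."""
    logits = np.asarray(logits, dtype=np.float64).ravel()
    exp = np.exp(logits - logits.max())  # subtract max for numerical stability
    probs = exp / exp.sum()
    idx = int(probs.argmax())
    return classes[idx], float(probs[idx])


def predict(model_path, class_path, img_path, image_size=224):
    """Run one image through an ONNX model on the CPU."""
    import onnxruntime as ort  # lazy import: only needed for inference

    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: preprocess(img_path, image_size)})
    return postprocess(outputs[0], load_classes(class_path))
```

Usage: label, prob = predict("models/Cats_vs_Dogs/cats_vs_dogs_resnet18_exp_1.onnx", "models/Cats_vs_Dogs/classes.txt", "test1.jpg"). If the exported model already ends in a softmax, the reported probability is distorted by the second softmax, but the predicted label is unchanged because softmax preserves the argmax.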
For better results, try other models and different hyperparameter settings.