Simple CNN for Image Classification

Lahiru Rathnayake
3 min read · Apr 11, 2022


Image by Author.

Introduction

Image classification is one of the most common computer vision problems: an algorithm analyzes an image and assigns a label to the object in it. Many kinds of models have been successful at image classification tasks. The Simple CNN library was developed as a wrapper around PyTorch transfer learning with pre-trained models to make image classification easier, and it also includes ONNX inference for the trained models. In this post, we'll walk through how to train a model on a custom dataset with Simple CNN and run inference with it.

Data Preparation

For this work, the well-known Cats vs Dogs dataset was used as the primary dataset. Download Dogs vs. Cats Redux: Kernels Edition from Kaggle and split it into train and val folders with the following structure:

Cats_vs_Dogs
├── train
│   ├── Cat
│   │   ├── cat.0.jpg
│   │   ├── cat.1.jpg
│   │   ├── cat.2.jpg
│   │   ├── .........
│   │   └── cat.500.jpg
│   └── Dog
│       ├── dog.0.jpg
│       ├── dog.1.jpg
│       ├── dog.2.jpg
│       ├── .........
│       └── dog.500.jpg
└── val
    ├── Cat
    │   ├── cat.501.jpg
    │   ├── cat.502.jpg
    │   ├── cat.503.jpg
    │   ├── .........
    │   └── cat.600.jpg
    └── Dog
        ├── dog.501.jpg
        ├── dog.502.jpg
        ├── dog.503.jpg
        ├── .........
        └── dog.600.jpg
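The split can also be scripted. The sketch below is a minimal example and not part of Simple CNN. It assumes the Kaggle training archive was extracted to a flat folder of files named like cat.0.jpg and dog.0.jpg; the helper name split_cats_dogs and the paths you pass to it are hypothetical. The default index ranges mirror the tree above.

```python
import shutil
from pathlib import Path


def split_cats_dogs(src_dir, dst_dir, train_ids=range(0, 501), val_ids=range(501, 601)):
    """Copy cat.N.jpg / dog.N.jpg from a flat folder into train/ and val/ class folders.

    Hypothetical helper: assumes Kaggle-style file names such as cat.0.jpg and dog.0.jpg.
    Returns the number of files copied.
    """
    src, dst = Path(src_dir), Path(dst_dir)
    classes = {"cat": "Cat", "dog": "Dog"}  # file-name prefix -> class folder name
    splits = {"train": train_ids, "val": val_ids}
    copied = 0
    for split, ids in splits.items():
        for prefix, folder in classes.items():
            out_dir = dst / split / folder
            out_dir.mkdir(parents=True, exist_ok=True)
            for i in ids:
                src_file = src / f"{prefix}.{i}.jpg"
                if src_file.exists():  # silently skip indices missing from the download
                    shutil.copy2(src_file, out_dir / src_file.name)
                    copied += 1
    return copied
```

For example, split_cats_dogs("train", "Data/Images/Cats_vs_Dogs") would populate the folder that cfg.data_dir points to in the training configuration. Copying rather than moving keeps the original download intact.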

Install

Clone the repo and install the packages listed in requirements.txt in a Python environment:

git clone https://github.com/LahiRumesh/simple_cnn.git
cd simple_cnn
pip install -r requirements.txt

Training

Use config.py to change the model configuration and hyperparameters. Simple CNN offers several pre-trained models; here, I used resnet18 with pre-trained weights for the Cats vs Dogs classification task. You can swap in any other classification model available in Simple CNN, and the PyTorch documentation has more information about the hyperparameters.

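cfg.data_dir = 'Data/Images/Cats_vs_Dogs'  # dataset root containing train/ and val/
cfg.device = '0'  # device id; '0' is typically the first GPU
cfg.image_size = 224  # input image resolution
cfg.batch_size = 8  # images per training batch
cfg.epochs = 20  # number of passes over the training set
cfg.model = 'resnet18'  # backbone architecture
cfg.pretrained = True  # start from pre-trained weights (transfer learning)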

Run cnn_train.py to start training. All logs are sent to Weights & Biases (wandb), so if you are not logged into your wandb account, log in before training starts (for example, by running wandb login). For each epoch, the script computes the loss and the training and validation accuracy and logs them to wandb.

python cnn_train.py
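The per-epoch bookkeeping described above (loss plus training and validation accuracy) can be sketched as a generic PyTorch loop. This is not the contents of cnn_train.py: the function names, the SGD optimizer, and the learning rate are assumptions. Metrics go to a log_fn callback, so you can pass print for local runs or wandb.log after calling wandb.init().

```python
import torch
from torch import nn


def run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over `loader`; train when an optimizer is given, otherwise evaluate.

    Returns (mean loss per sample, accuracy).
    """
    training = optimizer is not None
    model.train(training)  # toggles dropout / batch-norm behaviour
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)  # mean loss over the batch
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)  # convert batch mean to a sum
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, correct / seen


def fit(model, train_loader, val_loader, epochs, lr=1e-3, device="cpu", log_fn=print):
    """Train for `epochs` epochs, logging train/val loss and accuracy after each one."""
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)  # assumed optimizer
    history = []
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = run_epoch(model, val_loader, criterion, device)
        metrics = {
            "epoch": epoch,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
        }
        log_fn(metrics)  # e.g. wandb.log after wandb.init(project=...)
        history.append(metrics)
    return history
```

Passing the logger in as a function keeps the loop independent of wandb, so the same code runs offline or in tests.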

For each training experiment, an ONNX weight file named after the model and a class file are generated in the “models/Cats_vs_Dogs” folder.
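The class file matters at inference time because it maps the model's output indices back to labels. Here is a hedged sketch, assuming Simple CNN loads data with torchvision's ImageFolder (which assigns indices to the sorted sub-folder names) and stores one class name per line; write_class_file is a hypothetical helper.

```python
from pathlib import Path


def write_class_file(train_dir, out_file):
    """Write the class names, one per line, in ImageFolder's label-index order."""
    # ImageFolder sorts sub-folder names, so index 0 is the alphabetically first class
    classes = sorted(p.name for p in Path(train_dir).iterdir() if p.is_dir())
    out_path = Path(out_file)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(classes) + "\n")
    return classes
```

For the Cats vs Dogs folders this yields Cat at index 0 and Dog at index 1.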

Figure: loss graphs for Cats vs Dogs.
Figure: accuracy graphs for Cats vs Dogs.

After training, the model reaches a validation accuracy of 0.9650.
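To put that number in perspective: assuming the validation split matches the tree in the data preparation section (cat.501 to cat.600 and dog.501 to dog.600, so 200 images in total), an accuracy of 0.9650 corresponds to 193 correct predictions, i.e. 7 misclassified images.

```latex
\text{accuracy} = \frac{\text{correct predictions}}{\text{validation images}} = \frac{193}{200} = 0.9650
```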

Inference

After training, you can run inference with the saved ONNX weight file using cnn_inference.py. Execute the following command:

python cnn_inference.py --model_path=models/Cats_vs_Dogs/cats_vs_dogs_resnet18_exp_1.onnx --class_path=models/Cats_vs_Dogs/classes.txt --img_path=test1.jpg --image_size=224
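To call the exported model from your own Python code instead of the script, a minimal sketch with onnxruntime, NumPy, and Pillow might look like this. It is not cnn_inference.py: the function names are hypothetical, and the preprocessing (resize to image_size, ImageNet mean/std normalization, NCHW float32 input) and the one-name-per-line class file are assumptions that must match how the model was actually trained and exported.

```python
from pathlib import Path

import numpy as np
from PIL import Image

# Assumed normalization: ImageNet mean/std, the usual choice for torchvision pre-trained backbones
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(img, image_size=224):
    """Convert a PIL image to a normalized float32 batch of shape (1, 3, H, W)."""
    img = img.convert("RGB").resize((image_size, image_size))
    x = np.asarray(img, dtype=np.float32) / 255.0  # HWC, values in [0, 1]
    x = (x - MEAN) / STD  # per-channel normalization (broadcasts over H and W)
    return np.ascontiguousarray(x.transpose(2, 0, 1)[np.newaxis])  # HWC -> NCHW


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def predict(model_path, class_path, img_path, image_size=224):
    """Return (label, probability) for one image using an exported ONNX classifier."""
    import onnxruntime as ort  # imported here so preprocess/softmax work without onnxruntime

    # Assumed class-file format: one class name per line, in label-index order
    lines = Path(class_path).read_text().splitlines()
    classes = [line.strip() for line in lines if line.strip()]
    session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    with Image.open(img_path) as img:
        x = preprocess(img, image_size)
    logits = session.run(None, {input_name: x})[0]  # first output, shape (1, num_classes)
    probs = softmax(logits)[0]  # if the model already ends in softmax, argmax is unchanged
    best = int(probs.argmax())
    return classes[best], float(probs[best])
```

predict(model_path, class_path, img_path) takes the same inputs as the command above and returns the predicted label with its softmax probability. If predictions look wrong, a preprocessing mismatch with training is the first thing to check.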

For better results, try other models and different hyperparameter settings.
