This repository contains code for training and evaluating the Mask Classifier model, implemented mostly in Python 3 with the Keras framework.
For both training and testing, we mostly use SSH (the Single Stage Headless face detector) to detect and crop face images.
Please read our paper here for further details on preparing the datasets and on training and testing the model.
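Once SSH has found a face, the face region has to be cropped before it is passed to the classifier. The snippet below is a minimal sketch of that cropping step, not the repository's code: the helper names `expand_and_clip_box` and `crop`, the `(x1, y1, x2, y2)` pixel box format, and the 20% margin are all assumptions for illustration.

```python
from typing import Tuple

# Assumed box format: (x1, y1, x2, y2) in pixel coordinates.
Box = Tuple[int, int, int, int]


def expand_and_clip_box(box: Box, img_w: int, img_h: int, margin: float = 0.2) -> Box:
    """Grow a face box by `margin` (a fraction of its width/height) on each side,
    then clip it to the image bounds so it is safe to use for cropping."""
    x1, y1, x2, y2 = box
    dw = int(round((x2 - x1) * margin))
    dh = int(round((y2 - y1) * margin))
    return (
        max(0, x1 - dw),
        max(0, y1 - dh),
        min(img_w, x2 + dw),
        min(img_h, y2 + dh),
    )


def crop(image, box: Box):
    """Crop a NumPy image array of shape (H, W, C), e.g. an OpenCV frame."""
    x1, y1, x2, y2 = box
    return image[y1:y2, x1:x2]
```

A small margin keeps some context (chin, cheeks) around the face, which is where a mask is visible; clipping prevents negative or out-of-range slice indices near the image border.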
This project was supported by VinAI Research and was carried out while the authors were interns there.
Our team consists of four members: Thang Pham (me), Bao Tran, Duy Pham, and Long Nguyen.
Our team tried different network architectures and trained on different frameworks, including PyTorch, Keras, and Caffe.
We observed that using a pretrained ResNet50 network gave us both high accuracy and fast training.
In our experiments, the Mask Classifier model achieved about 96.5% accuracy and ran at 10-15 FPS on a machine with an NVIDIA GeForce GTX Titan X.
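The README does not say exactly how the FPS figure was measured. One plausible way to reproduce such a number is to time the whole per-frame pipeline over a sequence of frames, as in this sketch; `measure_fps` is a hypothetical helper, and `process_frame` stands in for whatever detection, cropping, and classification code you run.

```python
import time
from typing import Callable, Iterable


def measure_fps(process_frame: Callable[[object], object], frames: Iterable[object]) -> float:
    """Return the average frames per second of `process_frame` over `frames`.

    `process_frame` should run the full per-frame pipeline
    (face detection, cropping, and mask classification)."""
    count = 0
    start = time.perf_counter()  # monotonic, high-resolution clock
    for frame in frames:
        process_frame(frame)
        count += 1
    elapsed = time.perf_counter() - start
    if count == 0 or elapsed == 0:
        raise ValueError("need at least one frame and a non-zero elapsed time")
    return count / elapsed
```

For example, `measure_fps(my_pipeline, frames)` with a few hundred video frames gives a more stable estimate than timing a single frame, since the first frames often include one-off warm-up costs.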
- Clone the repository:
git clone https://github.com/aome510/Mask-Classifier.git
- Set up SSH (the face detector):
- Install the Python requirements:
pip install -r requirements.txt
To run the demo, first download the pretrained Mask Classifier model from here and save it in a folder named model/.
After downloading the model, you can run the demo on a webcam:
python demo.py
To run the demo on videos instead, download our demo videos from here and save them in a folder named data/videos/. Then modify demo.py to run the demo on those videos.
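The README does not show what to change in demo.py. Assuming it reads frames with OpenCV's `cv2.VideoCapture` (which accepts either a device index such as `0` for the default webcam or a video file path), a minimal sketch is to choose the source up front. `video_sources` and the set of accepted extensions are hypothetical.

```python
from pathlib import Path
from typing import List, Union

# Assumed set of video extensions; adjust to match your files.
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv"}


def video_sources(video_dir: str = "data/videos") -> List[Union[int, str]]:
    """Return the sorted video file paths in `video_dir`, or [0]
    (the default webcam) if the folder is missing or holds no videos."""
    folder = Path(video_dir)
    if not folder.is_dir():
        return [0]
    files = sorted(
        str(p)
        for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in VIDEO_EXTENSIONS
    )
    return files or [0]


# Hypothetical use inside demo.py, assuming it is built on OpenCV:
# import cv2
# for source in video_sources():
#     cap = cv2.VideoCapture(source)  # device index or file path
#     while True:
#         ok, frame = cap.read()
#         if not ok:
#             break
#         ...  # detect faces, crop, classify, draw results
#     cap.release()
```

Falling back to the webcam keeps the default behavior unchanged when data/videos/ is empty.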
- CelebA dataset: You can download the CelebA dataset from here and save it in a folder named data/celebA/
- WiderFace dataset: You can download the WiderFace dataset from here and save it in a folder named data/WiderFace/
- MAFA dataset: You can download the MAFA dataset from here and save it in a folder named data/MAFA/
- Mask Classifier dataset (our dataset): You can download the Mask Classifier dataset from here and save it in a folder named data/mask_classifier/
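Before running gen_data.py it can help to confirm that every dataset landed in the expected folder. This is a minimal sketch, not part of the repository: `missing_dataset_dirs` is a hypothetical helper, and it assumes the folder paths above are relative to the repository root.

```python
from pathlib import Path
from typing import List

# Folder names taken from the dataset list above (relative to the repo root).
REQUIRED_DIRS = [
    "data/celebA",
    "data/WiderFace",
    "data/MAFA",
    "data/mask_classifier",
]


def missing_dataset_dirs(root: str = ".") -> List[str]:
    """Return the required dataset folders that are absent or empty under `root`."""
    missing = []
    for rel in REQUIRED_DIRS:
        path = Path(root) / rel
        # An empty folder usually means the download or extraction failed.
        if not path.is_dir() or not any(path.iterdir()):
            missing.append(rel)
    return missing


# Example: run from the repository root before gen_data.py
# missing = missing_dataset_dirs()
# if missing:
#     print("Missing or empty dataset folders:", ", ".join(missing))
```

Note that folder names are case sensitive on Linux, so data/celebA/ and data/CelebA/ are different folders.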
After downloading all the datasets listed above, run:
python gen_data.py
to combine the datasets and split the combined data for training and validating the model.
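The README does not describe how gen_data.py splits the data. As a rough illustration of the idea only, here is a reproducible shuffled split; `train_val_split`, the `(image_path, label)` sample format, the 20% validation fraction, and the seed are all hypothetical.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def train_val_split(
    samples: Sequence[T], val_fraction: float = 0.2, seed: int = 42
) -> Tuple[List[T], List[T]]:
    """Shuffle `samples` with a fixed seed and split off `val_fraction` of
    them for validation. Returns (train, val)."""
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be strictly between 0 and 1")
    items = list(samples)
    # A local RNG makes the split reproducible without touching global state.
    random.Random(seed).shuffle(items)
    n_val = int(round(len(items) * val_fraction))
    return items[n_val:], items[:n_val]


# Hypothetical (image_path, label) pairs, with 1 meaning "mask".
samples = [(f"img_{i}.jpg", i % 2) for i in range(10)]
train, val = train_val_split(samples)
```

Fixing the seed means retraining with a different network compares models on the same validation images.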
After preparing the combined dataset, you can start training the model:
python train.py
By default, we train the model with the ResNet50 network. You can switch to the inception_resnet_v2 network by modifying train.py.
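The README does not show how train.py selects the network. The sketch below illustrates one way to make the backbone configurable, written against `tf.keras.applications` in TensorFlow 2.x (the repository may use standalone Keras, where the same models live under `keras.applications`). `build_model`, `input_shape_for`, and the single-sigmoid mask/no-mask head are hypothetical, not the authors' code.

```python
from typing import Dict, Optional, Tuple

# Default ImageNet input sizes of the two Keras backbones mentioned above.
INPUT_SIZES: Dict[str, int] = {
    "resnet50": 224,
    "inception_resnet_v2": 299,
}


def input_shape_for(arch: str) -> Tuple[int, int, int]:
    """Return the (height, width, channels) input shape for a backbone name."""
    if arch not in INPUT_SIZES:
        raise ValueError(f"unknown architecture {arch!r}; choose from {sorted(INPUT_SIZES)}")
    size = INPUT_SIZES[arch]
    return (size, size, 3)


def build_model(arch: str = "resnet50", weights: Optional[str] = "imagenet"):
    """Build a binary mask / no-mask classifier on top of a pretrained backbone.

    TensorFlow is imported lazily so the helper above works without it."""
    import tensorflow as tf  # assumes TensorFlow 2.x is installed

    shape = input_shape_for(arch)  # validates `arch` first
    backbones = {
        "resnet50": tf.keras.applications.ResNet50,
        "inception_resnet_v2": tf.keras.applications.InceptionResNetV2,
    }
    base = backbones[arch](
        include_top=False,  # drop the 1000-class ImageNet head
        weights=weights,
        input_shape=shape,
        pooling="avg",  # global average pooling -> one feature vector per image
    )
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(base.output)
    model = tf.keras.Model(inputs=base.input, outputs=outputs)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    return model
```

Each backbone also expects its own input preprocessing (`tf.keras.applications.resnet50.preprocess_input` versus `tf.keras.applications.inception_resnet_v2.preprocess_input`), so switching networks usually means switching the preprocessing function and the crop size as well.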
Please read this if you want to train with different network architectures.