New: Visualize PBA and applied augmentations with the notebook pba.ipynb!
Now with Python 3 support.
Population Based Augmentation (PBA) is an algorithm that quickly and efficiently learns data augmentation functions for neural network training. PBA matches state-of-the-art results on CIFAR with one thousand times less compute, enabling researchers and practitioners to effectively learn new augmentation policies using a single workstation GPU.
This repository contains the TensorFlow and Python code for the work "Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules" (http://arxiv.org/abs/1905.05393). It covers both training models with the reported augmentation schedules and discovering new augmentation policy schedules.
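To make the idea of an augmentation policy schedule concrete, here is a minimal, framework-free sketch in Python 3. It is not the repository's implementation: the two toy operations, the `(probability, magnitude)` encoding, and the rule of applying at most 0, 1, or 2 operations per image (drawn with weights 0.2, 0.3, 0.5) are illustrative assumptions, and images are plain lists of floats so the example runs without TensorFlow.

```python
import random

# Toy operations (hypothetical): each takes (image, magnitude), returns a new image.
# Images are lists of floats in [0, 1] so the sketch runs without TensorFlow.
def brightness(img, mag):
    return [min(1.0, x + 0.05 * mag) for x in img]

def invert(img, mag):
    return [1.0 - x for x in img]  # magnitude is ignored for this op

OPS = {"brightness": brightness, "invert": invert}

# A policy maps op name -> (probability in [0, 1], integer magnitude).
# A schedule is a list of (start_epoch, policy) sorted by start_epoch.

def policy_at(schedule, epoch):
    """Return the policy whose start_epoch is the latest one <= epoch."""
    active = schedule[0][1]
    for start, policy in schedule:
        if start > epoch:
            break
        active = policy
    return active

def augment(img, policy, rng, count_weights=(0.2, 0.3, 0.5)):
    """Apply up to k ops in random order; k is drawn using count_weights (assumption)."""
    max_ops = rng.choices(range(len(count_weights)), weights=count_weights)[0]
    names = list(policy)
    rng.shuffle(names)
    applied = 0
    for name in names:
        if applied >= max_ops:
            break
        prob, mag = policy[name]
        if rng.random() < prob:  # each op fires with its own probability
            img = OPS[name](img, mag)
            applied += 1
    return img
```

For example, `augment(image, policy_at(schedule, epoch), random.Random(0))` augments one image with whichever policy is active at the current epoch; in a schedule, the probabilities and magnitudes change over the course of training instead of staying fixed.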
See below for a visualization of our augmentation strategy.
Code supports Python 2 and 3.
pip install -r requirements.txt
bash datasets/cifar10.sh
bash datasets/cifar100.sh
Dataset | Model | Test Error (%) |
---|---|---|
CIFAR-10 | Wide-ResNet-28-10 | 2.58 |
CIFAR-10 | Shake-Shake (26 2x32d) | 2.54 |
CIFAR-10 | Shake-Shake (26 2x96d) | 2.03 |
CIFAR-10 | Shake-Shake (26 2x112d) | 2.03 |
CIFAR-10 | PyramidNet+ShakeDrop | 1.46 |
Reduced CIFAR-10 | Wide-ResNet-28-10 | 12.82 |
Reduced CIFAR-10 | Shake-Shake (26 2x96d) | 10.64 |
CIFAR-100 | Wide-ResNet-28-10 | 16.73 |
CIFAR-100 | Shake-Shake (26 2x96d) | 15.31 |
CIFAR-100 | PyramidNet+ShakeDrop | 10.94 |
SVHN | Wide-ResNet-28-10 | 1.18 |
SVHN | Shake-Shake (26 2x96d) | 1.13 |
Reduced SVHN | Wide-ResNet-28-10 | 7.83 |
Reduced SVHN | Shake-Shake (26 2x96d) | 6.46 |
Scripts to reproduce results are located in scripts/table_*.sh. One argument, the model name, is required for all of the scripts. The available options are those reported for each dataset in Table 2 of the paper, among the choices: wrn_28_10, ss_32, ss_96, ss_112, pyramid_net. Hyperparameters are also located inside each script file.
For example, to reproduce CIFAR-10 results on Wide-ResNet-28-10:
bash scripts/table_1_cifar10.sh wrn_28_10
To reproduce Reduced SVHN results on Shake-Shake (26 2x96d):
bash scripts/table_4_svhn.sh rsvhn_ss_96
A good place to start is Reduced SVHN on Wide-ResNet-28-10, which can complete in under 10 minutes on a Titan XP GPU while reaching 91%+ test accuracy.
Training the larger models for 1800 epochs may take multiple days. For example, CIFAR-10 PyramidNet+ShakeDrop takes around 9 days on a Tesla V100 GPU.
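To reproduce a whole table rather than a single entry, a small Python driver can run the scripts back to back. This is a hypothetical helper, not part of the repository: it only shells out to `bash scripts/table_*.sh <model>` as described above, and `CIFAR10_MODELS` assumes that every model listed for CIFAR-10 in the results table is a valid argument to `scripts/table_1_cifar10.sh`.

```python
import os
import subprocess

# Assumption: all five CIFAR-10 models from the results table are accepted
# by scripts/table_1_cifar10.sh under the names listed above.
CIFAR10_MODELS = ["wrn_28_10", "ss_32", "ss_96", "ss_112", "pyramid_net"]

def table_commands(script, models):
    """Build one `bash <script> <model>` argv per model name."""
    return [["bash", script, model] for model in models]

def gpu_env(gpu=None):
    """Copy the current environment, optionally pinning a single GPU."""
    env = dict(os.environ)
    if gpu is not None:
        env["CUDA_VISIBLE_DEVICES"] = str(gpu)
    return env

def run_all(commands, gpu=None, dry_run=False):
    """Run commands one after another, stopping at the first failure.

    With dry_run=True nothing is executed and the commands are returned.
    """
    if dry_run:
        return commands
    env = gpu_env(gpu)
    for argv in commands:
        subprocess.run(argv, env=env, check=True)  # raises on non-zero exit
    return commands
```

For example, `run_all(table_commands("scripts/table_1_cifar10.sh", CIFAR10_MODELS), gpu=0)` would run the five CIFAR-10 models sequentially from the repository root. Given the training times above, passing `dry_run=True` first to inspect the command list is a cheap sanity check.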
Run PBA search on Wide-ResNet-40-2 with the file scripts/search.sh. One argument, the dataset name, is required. Choices are rsvhn or rcifar10.
Each trial is allocated a fraction of a GPU so that multiple trials can run on the same GPU. Reduced SVHN takes around an hour on a Titan XP GPU, and Reduced CIFAR-10 takes around 5 hours.
CUDA_VISIBLE_DEVICES=0 bash scripts/search.sh rsvhn
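The search follows the population based training (PBT) pattern: periodically, weaker trials copy a stronger trial and then perturb its augmentation hyperparameters. The sketch below illustrates one exploit/explore round on integer-discretized hyperparameters. It is an approximation for intuition, not the code in this repository; the bottom/top 25% truncation, the 0.2 resampling probability, the perturbation step of 0 to 3, and the example domains (probability on 0..10, magnitude on 0..9) are all assumptions. Real PBT also copies model weights; only hyperparameters are handled here.

```python
import random

def explore(hparams, domains, rng, resample_prob=0.2, max_step=3):
    """Perturb integer hyperparameters, keeping each within its (lo, hi) domain."""
    new = []
    for value, (lo, hi) in zip(hparams, domains):
        if rng.random() < resample_prob:
            value = rng.randint(lo, hi)  # resample uniformly from the domain
        else:
            step = rng.randint(0, max_step)  # inclusive: 0..max_step
            value = value + step if rng.random() < 0.5 else value - step
            value = max(lo, min(hi, value))  # clip back into the domain
        new.append(value)
    return new

def pbt_step(population, domains, rng, frac=0.25):
    """One exploit/explore round over a list of (score, hparams) pairs.

    The bottom `frac` of trials are replaced by perturbed copies of trials
    drawn from the top `frac`; the rest keep their hyperparameters.
    Returns the new hyperparameter lists, best-ranked survivors first.
    """
    ranked = sorted(population, key=lambda t: t[0], reverse=True)
    n = max(1, int(len(ranked) * frac))
    survivors = [list(h) for _, h in ranked[:-n]]
    replacements = [explore(rng.choice(ranked[:n])[1], domains, rng)
                    for _ in range(n)]
    return survivors + replacements
```

Clipping after a small signed step keeps values on the discrete grid, while occasional uniform resampling lets a trial escape a poor region that small steps would not leave.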
The schedules used during search can be retrieved from the Ray results directory, and the log files can be converted into policy schedules with the parse_log() function in pba/utils.py. For example, the policy schedule learned on Reduced CIFAR-10 over 200 epochs is split into probability and magnitude hyperparameter values (the two values for each augmentation operation are merged) and visualized below:
[Figure: Probability Hyperparameters over Time | Magnitude Hyperparameters over Time]
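As a rough illustration of how such plots could be produced, the sketch below splits a schedule into per-operation probability and magnitude series. It does not call `parse_log()`; the input format (a list of `(epoch, flat_hparams)` pairs), the layout of `flat_hparams` (the first `(probability, magnitude)` pair for every operation, followed by the second pair for every operation), and merging the two values by summing are all assumptions made for this example.

```python
def split_schedule(schedule, op_names):
    """Split a schedule into per-operation probability and magnitude series.

    schedule: list of (epoch, flat_hparams). flat_hparams is assumed to hold
    4 integers per operation: first-slot [p, m] pairs for every op, then
    second-slot [p, m] pairs for every op.
    Returns (epochs, probs, mags); probs/mags map op name -> list over epochs.
    """
    n_ops = len(op_names)
    epochs = []
    probs = {name: [] for name in op_names}
    mags = {name: [] for name in op_names}
    for epoch, flat in schedule:
        if len(flat) != 4 * n_ops:
            raise ValueError("expected %d values, got %d" % (4 * n_ops, len(flat)))
        epochs.append(epoch)
        for i, name in enumerate(op_names):
            p1, m1 = flat[2 * i], flat[2 * i + 1]  # first slot for this op
            j = 2 * (n_ops + i)
            p2, m2 = flat[j], flat[j + 1]  # second slot for this op
            # Merge the two values per operation by summing (assumption).
            probs[name].append(p1 + p2)
            mags[name].append(m1 + m2)
    return epochs, probs, mags
```

The resulting dictionaries can then be handed to a plotting library, for example one `matplotlib.pyplot.plot(epochs, probs[name])` call per operation.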
If you use PBA in your research, please cite:
@inproceedings{ho2019pba,
title = {Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules},
author = {Daniel Ho and
Eric Liang and
Ion Stoica and
Pieter Abbeel and
Xi Chen
},
booktitle = {ICML},
year = {2019}
}