Population Based Augmentation (PBA)

New: Visualize PBA and applied augmentations with the notebook pba.ipynb!

Now with Python 3 support.

Table of Contents

  1. Introduction
  2. Getting Started
  3. Reproduce Results
  4. Run PBA Search
  5. Citation

Introduction

Population Based Augmentation (PBA) is an algorithm that quickly and efficiently learns data augmentation functions for neural network training. PBA matches state-of-the-art results on CIFAR with one thousand times less compute, enabling researchers and practitioners to effectively learn new augmentation policies using a single workstation GPU.

This repository contains the code for the paper "Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules" (http://arxiv.org/abs/1905.05393), implemented in TensorFlow and Python. It covers training models with the reported augmentation schedules and discovering new augmentation policy schedules.

See below for a visualization of our augmentation strategy.

Getting Started

Code supports Python 2 and 3.

Install requirements

pip install -r requirements.txt

Download CIFAR-10/CIFAR-100 datasets

bash datasets/cifar10.sh
bash datasets/cifar100.sh

Reproduce Results

Dataset            Model                      Test Error (%)
CIFAR-10           Wide-ResNet-28-10          2.58
CIFAR-10           Shake-Shake (26 2x32d)     2.54
CIFAR-10           Shake-Shake (26 2x96d)     2.03
CIFAR-10           Shake-Shake (26 2x112d)    2.03
CIFAR-10           PyramidNet+ShakeDrop       1.46
Reduced CIFAR-10   Wide-ResNet-28-10          12.82
Reduced CIFAR-10   Shake-Shake (26 2x96d)     10.64
CIFAR-100          Wide-ResNet-28-10          16.73
CIFAR-100          Shake-Shake (26 2x96d)     15.31
CIFAR-100          PyramidNet+ShakeDrop       10.94
SVHN               Wide-ResNet-28-10          1.18
SVHN               Shake-Shake (26 2x96d)     1.13
Reduced SVHN       Wide-ResNet-28-10          7.83
Reduced SVHN       Shake-Shake (26 2x96d)     6.46

Scripts to reproduce results are located in scripts/table_*.sh. Each script takes one required argument, the model name. The available options are those reported for each dataset in Table 2 of the paper, chosen from: wrn_28_10, ss_32, ss_96, ss_112, pyramid_net. Hyperparameters are also set inside each script file.

For example, to reproduce CIFAR-10 results on Wide-ResNet-28-10:

bash scripts/table_1_cifar10.sh wrn_28_10

To reproduce Reduced SVHN results on Shake-Shake (26 2x96d):

bash scripts/table_4_svhn.sh rsvhn_ss_96
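
When several models need to be queued, the two commands above can be wrapped in a small Python helper. The following is only a sketch: build_command and run are hypothetical names that are not part of this repository, and only the two script paths shown above are included because the other scripts/table_*.sh file names are not listed here.

```python
import subprocess

# Script paths copied from the two example commands above.
# The keys ("cifar10", "svhn") are labels invented for this sketch.
SCRIPTS = {
    "cifar10": "scripts/table_1_cifar10.sh",
    "svhn": "scripts/table_4_svhn.sh",
}


def build_command(table, model_name):
    """Return the argv list for one reproduction run (hypothetical helper)."""
    if table not in SCRIPTS:
        raise ValueError("unknown table: %r" % (table,))
    return ["bash", SCRIPTS[table], model_name]


def run(table, model_name, dry_run=True):
    """Print the command; execute it only when dry_run is False."""
    cmd = build_command(table, model_name)
    print(" ".join(cmd))
    if not dry_run:
        # Raises CalledProcessError if the script exits with a non-zero status.
        subprocess.check_call(cmd)
    return cmd
```

Calling run("cifar10", "wrn_28_10") prints the command without launching training. Passing dry_run=False starts the actual run from the repository root.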

A good place to start is Reduced SVHN on Wide-ResNet-28-10, which completes in under 10 minutes on a Titan XP GPU and reaches 91%+ test accuracy.

Running the larger models for 1800 epochs may require multiple days of training. For example, CIFAR-10 PyramidNet+ShakeDrop takes around 9 days on a Tesla V100 GPU.

Run PBA Search

Run PBA search on Wide-ResNet-40-2 with the file scripts/search.sh. One argument, the dataset name, is required. Choices are rsvhn or rcifar10.

The script requests a fraction of a GPU per trial so that multiple trials can run on the same GPU. Reduced SVHN takes around an hour on a Titan XP GPU, and Reduced CIFAR-10 takes around 5 hours.

CUDA_VISIBLE_DEVICES=0 bash scripts/search.sh rsvhn
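
To get a feel for how a fractional GPU request translates into concurrent trials, the arithmetic can be sketched as below. The function trials_per_gpu is a hypothetical helper, and the fractions used are illustrative values rather than the setting in scripts/search.sh. A fractional request is typically a scheduling hint only, so the trials sharing a GPU must also fit in its memory together.

```python
import math


def trials_per_gpu(gpu_fraction):
    """Number of trials that fit on one GPU when each requests gpu_fraction.

    Hypothetical helper for illustration; not part of the repository.
    """
    if not 0 < gpu_fraction <= 1:
        raise ValueError("gpu_fraction must be in (0, 1]")
    # The small epsilon guards against floating-point results such as 3.9999999.
    return int(math.floor(1.0 / gpu_fraction + 1e-9))


if __name__ == "__main__":
    for fraction in (1.0, 0.5, 0.25, 0.3):
        print(fraction, trials_per_gpu(fraction))
```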

The schedules resulting from the search can be retrieved from the Ray result directory, and the log files can be converted into policy schedules with the parse_log() function in pba/utils.py. For example, the policy schedule learned on Reduced CIFAR-10 over 200 epochs is split into probability and magnitude hyperparameter values (the two values for each augmentation operation are merged) and visualized below:

[Figure panels: Probability Hyperparameters over Time; Magnitude Hyperparameters over Time]
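
The split into probability and magnitude series described above can be sketched roughly as follows. This is not the repository's parse_log() and does not reproduce its output format. It assumes a hypothetical layout in which each epoch is a list of (op_name, probability, magnitude) tuples, with each operation possibly appearing twice, and it merges duplicates by averaging, which is also an assumption.

```python
from collections import defaultdict


def split_schedule(schedule):
    """Split a per-epoch schedule into probability and magnitude series.

    `schedule` holds one entry per epoch; each entry is a list of
    (op_name, probability, magnitude) tuples (assumed layout). When an
    operation appears more than once in an epoch, its values are averaged
    (assumed merge rule).
    """
    probs = defaultdict(list)
    mags = defaultdict(list)
    for epoch_policy in schedule:
        p_acc = defaultdict(list)
        m_acc = defaultdict(list)
        for name, prob, mag in epoch_policy:
            p_acc[name].append(prob)
            m_acc[name].append(mag)
        for name in p_acc:
            count = float(len(p_acc[name]))
            probs[name].append(sum(p_acc[name]) / count)
            mags[name].append(sum(m_acc[name]) / count)
    return dict(probs), dict(mags)


# Illustrative values only, not results from an actual search.
example = [
    [("Rotate", 0.25, 4), ("Rotate", 0.75, 6), ("Cutout", 0.0, 0)],
    [("Rotate", 0.5, 5), ("Rotate", 0.5, 7), ("Cutout", 0.125, 2)],
]
probabilities, magnitudes = split_schedule(example)
```

Each returned dictionary maps an operation name to a list with one value per epoch, which is the shape needed to plot one line per operation over time.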

Citation

If you use PBA in your research, please cite:

@inproceedings{ho2019pba,
  title     = {Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules},
  author    = {Daniel Ho and
               Eric Liang and
               Ion Stoica and
               Pieter Abbeel and
               Xi Chen
  },
  booktitle = {ICML},
  year      = {2019}
}