# TF data augmentation on GPU
def augment(images, labels,
            resize=None,  # (height, width) target size, or None
            rotate=0,  # Maximum rotation angle in degrees
            crop_probability=0,  # How often we do crops
            crop_min_percent=0.6,  # Minimum linear dimension of a crop
            crop_max_percent=1.,  # Maximum linear dimension of a crop
            mixup=0,  # Mixup coefficient (arXiv:1710.09412); 0 disables
            horizontal_flip=False,  # Mirror left-right with probability 0.5
            vertical_flip=False):  # Mirror top-bottom with probability 0.5
    """Build graph ops that randomly augment a batch of images.

    Flip, rotation and crop are each expressed as a per-image projective
    transform (a row of 8 coefficients); the enabled ones are composed and
    applied in a single `tf.contrib.image.transform` call.

    Args:
        images: 4-D image batch; axes 0, 1 and 2 are read as batch, height
            and width. Non-float32 input is converted with
            `tf.image.convert_image_dtype` and rescaled to [-1, 1];
            float32 input is used as given.
        labels: Per-image labels, cast to float32. Mixup multiplies them
            by a vector of length `batch`, so they are presumably expected
            to have shape [batch] -- TODO confirm for other label shapes.
        resize: Size handed to `tf.image.resize_bilinear`, which expects
            (new_height, new_width); None skips resizing.
        rotate: Maximum absolute rotation in degrees; each image gets an
            angle drawn uniformly from [-rotate, rotate]. 0 disables.
        crop_probability: Probability that an image is cropped. 0 disables.
        crop_min_percent: Smallest crop side, as a fraction of the image.
        crop_max_percent: Largest crop side, as a fraction of the image.
        mixup: Concentration of the symmetric Beta distribution the mixing
            weights are sampled from; 0 disables mixup.
        horizontal_flip: Whether to randomly mirror images left-right.
        vertical_flip: Whether to randomly mirror images top-bottom.

    Returns:
        An `(images, labels)` tuple of augmented tensors; `labels` is
        float32.

    `horizontal_flip` and `vertical_flip` come last in the signature so
    that existing positional calls keep their meaning.
    """
    # NOTE(review): tf.image.resize_bilinear is documented to return
    # float32, so integer images that are resized here skip the [-1, 1]
    # normalisation below -- confirm whether that is intended.
    if resize is not None:
        images = tf.image.resize_bilinear(images, resize)

    # The original author measured that casting inside the graph (on the
    # GPU) improves training performance.
    if images.dtype != tf.float32:
        images = tf.image.convert_image_dtype(images, dtype=tf.float32)
        # convert_image_dtype yields [0, 1]; shift and scale to [-1, 1].
        images = tf.subtract(images, 0.5)
        images = tf.multiply(images, 2.0)
    labels = tf.to_float(labels)

    with tf.name_scope('augmentation'):
        shp = tf.shape(images)
        batch_size, height, width = shp[0], shp[1], shp[2]
        width = tf.cast(width, tf.float32)
        height = tf.cast(height, tf.float32)

        # Row layout is [a0, a1, a2, b0, b1, b2, c0, c1]; an output pixel
        # (x, y) samples the input at (a0*x + a1*y + a2, b0*x + b1*y + b2)
        # when c0 = c1 = 0.
        identity = tf.constant([1, 0, 0, 0, 1, 0, 0, 0], dtype=tf.float32)

        def tile_over_batch(transform):
            """Repeat one length-8 transform for every image (-> Nx8)."""
            return tf.tile(tf.expand_dims(transform, 0), [batch_size, 1])

        def with_probability(transform, probability):
            """Keep each row of `transform` with `probability`, else identity."""
            coin = tf.less(
                tf.random_uniform([batch_size], 0, 1.0), probability)
            return tf.where(coin, transform, tile_over_batch(identity))

        # The list of transformations that our images will go under.
        # Every element is an Nx8 tensor, where N is the batch size.
        transforms = []

        if horizontal_flip:
            # x' = (width - 1) - x: pixel columns are indexed 0..width-1,
            # so an offset of `width` would sample one column outside the
            # image.
            flip_transform = tf.convert_to_tensor(
                [-1., 0., width - 1., 0., 1., 0., 0., 0.], dtype=tf.float32)
            transforms.append(
                with_probability(tile_over_batch(flip_transform), 0.5))

        if vertical_flip:
            # y' = (height - 1) - y, same reasoning as above.
            flip_transform = tf.convert_to_tensor(
                [1., 0., 0., 0., -1., height - 1., 0., 0.], dtype=tf.float32)
            transforms.append(
                with_probability(tile_over_batch(flip_transform), 0.5))

        if rotate > 0:
            # math.radians avoids `rotate / 180` truncating to 0 for
            # integer degrees under Python 2 division.
            angle_rad = math.radians(rotate)
            angles = tf.random_uniform([batch_size], -angle_rad, angle_rad)
            transforms.append(
                tf.contrib.image.angles_to_projective_transforms(
                    angles, height, width))

        if crop_probability > 0:
            crop_pct = tf.random_uniform([batch_size], crop_min_percent,
                                         crop_max_percent)
            left = tf.random_uniform([batch_size], 0, width * (1 - crop_pct))
            top = tf.random_uniform([batch_size], 0, height * (1 - crop_pct))
            zeros = tf.zeros([batch_size])
            # Scale by crop_pct, then offset: `left` is the x offset (a2)
            # and `top` the y offset (b2), matching the flips above where
            # width pairs with a2 and height with b2.
            crop_transform = tf.stack([
                crop_pct, zeros, left,
                zeros, crop_pct, top,
                zeros, zeros,
            ], 1)
            transforms.append(
                with_probability(crop_transform, crop_probability))

        if transforms:
            images = tf.contrib.image.transform(
                images,
                tf.contrib.image.compose_transforms(*transforms),
                interpolation='BILINEAR')  # or 'NEAREST'

        def cshift(values):  # Circular shift in batch dimension
            return tf.concat([values[-1:, ...], values[:-1, ...]], 0)

        if mixup > 0:
            # tf.distributions.Beta requires float concentrations.
            alpha = float(mixup)
            lam = tf.distributions.Beta(alpha, alpha).sample(batch_size)
            # Shape [N, 1, 1, 1] so the weights broadcast over H, W, C.
            ll = tf.reshape(lam, [-1, 1, 1, 1])
            # Each example is blended with its predecessor in the batch.
            images = ll * images + (1 - ll) * cshift(images)
            labels = lam * labels + (1 - lam) * cshift(labels)

    return images, labels
# Usage example.
# These can be any tensors of matching type and dimensions.
images = tf.placeholder(tf.uint8, shape=(None, None, None, 3))
# `(None,)` is a 1-D shape of unknown length; `(None)` is just `None`,
# which leaves the placeholder's shape entirely unspecified.
labels = tf.placeholder(tf.uint64, shape=(None,))
images, labels = augment(images, labels,
                         horizontal_flip=True, rotate=15,
                         crop_probability=0.8, mixup=4)
# ... Now build your model and loss on top of images and labels
