Skip to content

Commit

Permalink
Resnet distribution strategies (tensorflow#3887)
Browse files Browse the repository at this point in the history
* begin transfer from contrib fork

more changes to resnet_run_loop

use AUTOTUNE in prefetch

first pass at resnet with functional distribution strategies

fix syntax error

delint

aesthetic tweaks

delint and fix typos

rip multi_gpu flag out of resnet entirely. Subject to saved model load verification

update cifar10 and imagenet tests to reflect that the model function no longer need to know about multi_gpu

fix imagenet test

start addressing PR comments

more PR response work

* misc tweaks

* add a comment

* final pr tweaks

* fix parsers
  • Loading branch information
Taylor Robie authored Apr 12, 2018
1 parent ad7755c commit 32aa656
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 257 deletions.
2 changes: 1 addition & 1 deletion official/mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class MNISTArgParser(argparse.ArgumentParser):

def __init__(self):
super(MNISTArgParser, self).__init__(parents=[
parsers.BaseParser(),
parsers.BaseParser(multi_gpu=True, num_gpu=False),
parsers.ImageModelParser(),
parsers.ExportParser(),
])
Expand Down
10 changes: 10 additions & 0 deletions official/resnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,13 @@ Other versions and formats:
* [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz)
* [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz)
* [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz)

## Compute Devices
Training is accomplished using the DistributionStrategies API. (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/README.md)

The appropriate distribution strategy is chosen based on the `--num_gpus` flag. By default this flag is one if TensorFlow is compiled with CUDA, and zero otherwise.

num_gpus:
+ 0: Use OneDeviceStrategy and train on CPU.
+ 1: Use OneDeviceStrategy and train on GPU.
+ 2+: Use MirroredStrategy (data parallelism) to distribute a batch between devices.
16 changes: 3 additions & 13 deletions official/resnet/cifar10_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,34 +103,25 @@ def preprocess_image(image, is_training):
return image


def input_fn(is_training, data_dir, batch_size, num_epochs=1,
num_parallel_calls=1, multi_gpu=False):
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
Args:
is_training: A boolean denoting whether the input is for training.
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers, and can be removed
when that is handled directly by Estimator.
Returns:
A dataset that can be used for iteration.
"""
filenames = get_filenames(is_training, data_dir)
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)

num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']

return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _NUM_IMAGES['train'],
parse_record, num_epochs, num_parallel_calls,
examples_per_epoch=num_images, multi_gpu=multi_gpu)
parse_record, num_epochs,
)


def get_synth_input_fn():
Expand Down Expand Up @@ -221,7 +212,6 @@ def loss_filter_fn(_):
version=params['version'],
loss_scale=params['loss_scale'],
loss_filter_fn=loss_filter_fn,
multi_gpu=params['multi_gpu'],
dtype=params['dtype']
)

Expand Down
112 changes: 44 additions & 68 deletions official/resnet/cifar10_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,87 +71,63 @@ def test_dataset_input_fn(self):
for pixel in row:
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)

def _cifar10_model_fn_helper(self, mode, version, dtype, multi_gpu=False):
with tf.Graph().as_default() as g:
input_fn = cifar10_main.get_synth_input_fn()
dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
spec = cifar10_main.cifar10_model_fn(
features, labels, mode, {
'dtype': dtype,
'resnet_size': 32,
'data_format': 'channels_last',
'batch_size': _BATCH_SIZE,
'version': version,
'loss_scale': 128 if dtype == tf.float16 else 1,
'multi_gpu': multi_gpu
})

predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape,
(_BATCH_SIZE, 10))
self.assertEqual(predictions['probabilities'].dtype, tf.float32)
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['classes'].dtype, tf.int64)

if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32)

if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)

for v in tf.trainable_variables():
self.assertEqual(v.dtype.base_dtype, tf.float32)

tensors_to_check = ('initial_conv:0', 'block_layer1:0', 'block_layer2:0',
'block_layer3:0', 'final_reduce_mean:0',
'final_dense:0')

for tensor_name in tensors_to_check:
tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
self.assertEqual(tensor.dtype, dtype,
'Tensor {} has dtype {}, while dtype {} was '
'expected'.format(tensor, tensor.dtype,
dtype))

def cifar10_model_fn_helper(self, mode, version, multi_gpu=False):
self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
multi_gpu=multi_gpu)
self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
multi_gpu=multi_gpu)
def cifar10_model_fn_helper(self, mode, version, dtype):
input_fn = cifar10_main.get_synth_input_fn()
dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
spec = cifar10_main.cifar10_model_fn(
features, labels, mode, {
'dtype': dtype,
'resnet_size': 32,
'data_format': 'channels_last',
'batch_size': _BATCH_SIZE,
'version': version,
'loss_scale': 128 if dtype == tf.float16 else 1,
})

predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape,
(_BATCH_SIZE, 10))
self.assertEqual(predictions['probabilities'].dtype, tf.float32)
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['classes'].dtype, tf.int64)

if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32)

if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)

def test_cifar10_model_fn_train_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)

def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)

def test_cifar10_model_fn_train_mode_multi_gpu_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True)
dtype=tf.float32)

def test_cifar10_model_fn_train_mode_multi_gpu_v2(self):
def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True)
dtype=tf.float32)

def test_cifar10_model_fn_eval_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1)
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
dtype=tf.float32)

def test_cifar10_model_fn_eval_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2)
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
dtype=tf.float32)

def test_cifar10_model_fn_predict_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1)
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
dtype=tf.float32)

def test_cifar10_model_fn_predict_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2)
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
dtype=tf.float32)

def _test_cifar10model_shape(self, version):
batch_size = 135
Expand Down
16 changes: 3 additions & 13 deletions official/resnet/imagenet_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,21 +154,14 @@ def parse_record(raw_record, is_training):
return image, label


def input_fn(is_training, data_dir, batch_size, num_epochs=1,
num_parallel_calls=1, multi_gpu=False):
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
"""Input function which provides batches for train or eval.
Args:
is_training: A boolean denoting whether the input is for training.
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers, and can be removed
when that is handled directly by Estimator.
Returns:
A dataset that can be used for iteration.
Expand All @@ -180,15 +173,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
# Shuffle the input files
dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)

num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']

# Convert to individual records
dataset = dataset.flat_map(tf.data.TFRecordDataset)

return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
num_epochs, num_parallel_calls, examples_per_epoch=num_images,
multi_gpu=multi_gpu)
num_epochs
)


def get_synth_input_fn():
Expand Down Expand Up @@ -300,7 +291,6 @@ def imagenet_model_fn(features, labels, mode, params):
version=params['version'],
loss_scale=params['loss_scale'],
loss_filter_fn=None,
multi_gpu=params['multi_gpu'],
dtype=params['dtype']
)

Expand Down
114 changes: 46 additions & 68 deletions official/resnet/imagenet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,88 +180,66 @@ def test_tensor_shapes_resnet_200_with_gpu_v1(self):
def test_tensor_shapes_resnet_200_with_gpu_v2(self):
self.tensor_shapes_helper(200, version=2, with_gpu=True)

def _resnet_model_fn_helper(self, mode, version, dtype, multi_gpu):
def resnet_model_fn_helper(self, mode, version, dtype):
"""Tests that the EstimatorSpec is given the appropriate arguments."""
with tf.Graph().as_default() as g:
tf.train.create_global_step()

input_fn = imagenet_main.get_synth_input_fn()
dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
spec = imagenet_main.imagenet_model_fn(
features, labels, mode, {
'dtype': dtype,
'resnet_size': 50,
'data_format': 'channels_last',
'batch_size': _BATCH_SIZE,
'version': version,
'loss_scale': 128 if dtype == tf.float16 else 1,
'multi_gpu': multi_gpu,
})

predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape,
(_BATCH_SIZE, _LABEL_CLASSES))
self.assertEqual(predictions['probabilities'].dtype, tf.float32)
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['classes'].dtype, tf.int64)

if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32)

if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)

tensors_to_check = ('initial_conv:0', 'initial_max_pool:0',
'block_layer1:0', 'block_layer2:0',
'block_layer3:0', 'block_layer4:0',
'final_reduce_mean:0', 'final_dense:0')

for tensor_name in tensors_to_check:
tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
self.assertEqual(tensor.dtype, dtype,
'Tensor {} has dtype {}, while dtype {} was '
'expected'.format(tensor, tensor.dtype,
dtype))

def resnet_model_fn_helper(self, mode, version, multi_gpu=False):
self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
multi_gpu=multi_gpu)
self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
multi_gpu=multi_gpu)
tf.train.create_global_step()

input_fn = imagenet_main.get_synth_input_fn()
dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
spec = imagenet_main.imagenet_model_fn(
features, labels, mode, {
'dtype': dtype,
'resnet_size': 50,
'data_format': 'channels_last',
'batch_size': _BATCH_SIZE,
'version': version,
'loss_scale': 128 if dtype == tf.float16 else 1,
})

predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape,
(_BATCH_SIZE, _LABEL_CLASSES))
self.assertEqual(predictions['probabilities'].dtype, tf.float32)
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['classes'].dtype, tf.int64)

if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32)

if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)

def test_resnet_model_fn_train_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)

def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)

def test_resnet_model_fn_train_mode_multi_gpu_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True)
dtype=tf.float32)

def test_resnet_model_fn_train_mode_multi_gpu_v2(self):
def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True)
dtype=tf.float32)

def test_resnet_model_fn_eval_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1)
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
dtype=tf.float32)

def test_resnet_model_fn_eval_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2)
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
dtype=tf.float32)

def test_resnet_model_fn_predict_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1)
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
dtype=tf.float32)

def test_resnet_model_fn_predict_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2)
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
dtype=tf.float32)

def _test_imagenetmodel_shape(self, version):
batch_size = 135
Expand Down
Loading

0 comments on commit 32aa656

Please sign in to comment.