
Resnet distribution strategies #3887

Merged: robieta merged 5 commits into master on Apr 12, 2018

Conversation

Contributor

@robieta commented Apr 5, 2018

This PR pulls the work of the distribution strategies team back into official/resnet. Specifically, it removes the replicate_model_fn and tower optimizer code, incorporates some changes to the dataset pipeline, and uses various distribution strategies in the Estimators.

@guptapriya Would you be so kind as to check that this is a faithful port of your code?

So far I have only performed one run: ResNet-50 v2 on ImageNet, which converged to 76% in 2 days. I will be performing a full battery of runs in the coming days.

@robieta requested review from karmel, guptapriya, and k-w-w on April 5, 2018
@robieta requested a review from nealwu as a code owner on April 5, 2018
Contributor

@k-w-w left a comment

Can you add a short description of DistributionStrategy to the README? The first time I heard the term, it wasn't immediately obvious that it referred to multi-GPU training, and looking up "DistributionStrategy" on Google didn't really help.

@@ -118,6 +118,7 @@ def __init__(self, add_help=False, data_dir=True, model_dir=True,
metavar="<BS>"
)

# TODO(taylorrobie@): depricate and only use DistributionStrategies
Contributor

nit: (sp) deprecate

Contributor Author

fixed.

@@ -151,6 +152,7 @@ def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True, max_train_steps=True):
super(PerformanceParser, self).__init__(add_help=add_help)

# TODO(taylorrobie@): depricate and only use DistributionStrategies
Contributor

nit: (sp) deprecate

Contributor Author

fixed.

tf.contrib.data.map_and_batch(
lambda value: parse_record_fn(value, is_training),
batch_size=per_device_batch_size,
num_parallel_batches=1))
Contributor

Why is num_parallel_batches set to 1 (wouldn't more improve performance)?

Contributor Author

It still parallelizes the lambda call, so we probably only need >1 if there are stragglers or num_cores > batch_size.

Contributor

If you want to make this more generic, the ideal value is num_parallel_batches = num_cores / batch_size.
Since batch_size is usually larger than num_cores, I used 1. But you could do that division and take the ceiling or something.

Contributor

Interestingly, when I tried this with replicate_model_fn alone, it actually slowed performance. Same for tf_cnn. See tensorflow/benchmarks#137 and the abandoned branch in which I attempted it: https://github.com/tensorflow/models/compare/feat/contrib-data . This is just FYI, assuming that DistStrat gets better performance, but it's not a bad idea to take in num_cpus in any case; Transformer uses that as well, so it will be useful elsewhere.

Contributor

Yes, that's definitely interesting. We saw a performance bump in both OneDeviceStrategy and MirroredStrategy with this approach. Agreed that it might still be worthwhile to take in the number of cores and use that to compute num_parallel_batches.

Contributor Author

I just tested this on a trivial ImageNet ResNet (so any limit is the input pipeline) and found no difference with num_parallel_batches. (This is 32 cores, 4xP100, batch_size=512.) I'm hard pressed to think of a case where num_cores > batch_size makes sense, so I'm inclined to leave it at 1 to keep the code simple.

Contributor

That sounds good to me if you add a comment describing the other approach. CC @mrry in case you have thoughts.

Contributor

Setting it to 1 seems like a fine way to go. Any larger would probably lead to congestion on the threadpool queues. The ideal number for a batch size of 512 is probably between 0 and 1, and there's an outstanding bug to support more precise control of the parallelism here. /cc @jsimsa
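For illustration, a minimal sketch of the ceiling heuristic discussed in this thread; the helper name and the example numbers are hypothetical, not part of this PR:

import math

def pick_num_parallel_batches(num_cores, per_device_batch_size):
  # One parallel batch is enough when the batch already spans all cores;
  # only go higher when num_cores exceeds the per-device batch size.
  return max(1, int(math.ceil(num_cores / float(per_device_batch_size))))

# Example: 32 cores with a per-device batch of 128 -> 1 parallel batch.
print(pick_num_parallel_batches(32, 128))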

accuracy = tf.metrics.accuracy(
tf.argmax(labels, axis=1), predictions['classes'])
else:
# Metrics are currently no compatible with distribution strategies
Contributor

nit: not compatible

Contributor Author

fixed.

model_function,
loss_reduction=tf.losses.Reduction.MEAN)
# TODO(taylorrobie@): remove when per_device is no longer needed.
assign_multi_gpu(flags)
Contributor

How does this relate to per_device?

Also, it appears flags.multi_gpu is only used in the line warn_on_multi_gpu_export(flags.multi_gpu). This can be changed to use flags.use_distribution_strategy, and the multi_gpu flag can be completely removed.

Contributor Author

It turns out we can probably remove multi_gpu altogether. Karmel just has to check that it doesn't break saved models.

Contributor

@guptapriya left a comment

Thanks @robieta! This looks great! Left some small comments and suggestions. Definitely a faithful port of the example, thank you.

@@ -110,14 +109,6 @@ def test_cifar10_model_fn_train_mode_v1(self):
def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)

def test_cifar10_model_fn_train_mode_multi_gpu_v1(self):
Contributor

Could we instead change these tests to run the distribution strategy version on multiple GPUs?

import tensorflow as tf # pylint: disable=g-bad-import-order
# pylint: disable=g-bad-import-order
import tensorflow as tf
from tensorflow.contrib.distribute.python import mirrored_strategy
Contributor

Actually, you shouldn't need this. You can use them directly as
tf.contrib.distribute.MirroredStrategy
tf.contrib.distribute.OneDeviceStrategy

Contributor Author

Excellent.


# Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to
# background all of the above processing work and keep it out of the
# critical training path.
dataset = dataset.prefetch(1)
dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
Contributor

thanks!

Contributor

Please add instructional comments -- what is AUTOTUNE? What does it do?
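As a reference point, a minimal sketch of the prefetch call in question with the kind of explanatory comment being requested; the toy dataset is just a stand-in for the real pipeline:

import tensorflow as tf

dataset = tf.data.Dataset.range(1000)  # stand-in for the real input pipeline

# tf.contrib.data.AUTOTUNE lets tf.data pick the prefetch buffer size at
# runtime, so DistributionStrategies can adjust how many batches to stage
# based on how many devices are present. Note that prefetch() returns a new
# Dataset, so the result must be assigned back.
dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)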


update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = tf.group(optimizer.minimize(loss, global_step), update_ops)
else:
train_op = None

accuracy = tf.metrics.accuracy(
tf.argmax(labels, axis=1), predictions['classes'])
if not distribute_lib.has_distribution_strategy():
Contributor

This can be tf.contrib.distribute.has_distribution_strategy().

Contributor

If we can check in code whether we have a dist strat via a tf function, why do we need to pass around use_distribution_strategies?

Contributor

You're absolutely right, @karmel. Any time one is in the distribution strategy scope, the has_distribution_strategy check should work. I believe in our earlier code we didn't have the input function running inside the scope, so we had to pass around this boolean. But since then we've moved input processing to be under the distribution scope as well, so we can now use has_distribution_strategy() almost everywhere.

I just checked everywhere it's used. I think it can be replaced with has_distribution_strategy() in all places except this one in resnet_main, and you already suggested removing that check from there entirely. So I think we can get rid of this flag entirely!

Contributor Author

Indeed, that flag is gone.
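A minimal sketch of the check being discussed, assuming the top-level tf.contrib.distribute.has_distribution_strategy() alias mentioned above; the wrapper function itself is hypothetical:

import tensorflow as tf

def maybe_accuracy_metric(labels, predictions):
  # Skip tf.metrics when running under a DistributionStrategy scope, since
  # Estimator metrics were not yet compatible with distribution strategies
  # at the time of this PR.
  if tf.contrib.distribute.has_distribution_strategy():
    return None
  return tf.metrics.accuracy(
      tf.argmax(labels, axis=1), predictions['classes'])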

from tensorflow.python.client import device_lib # pylint: disable=g-import-not-at-top

local_device_protos = device_lib.list_local_devices()
num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
Contributor

I wonder if this check is still useful in some form... e.g., should we check that the actual number of GPUs available >= gpus_for_distribution_strategy?

Contributor

I feel like that check should live under the hood in DistStrat-- we had to break our own rules and import from the private API here. Is it feasible to do that on the DistStrat side? Or just ignore for now and let the error bubble up when it is hit...

Contributor Author

+1 on living under DistStrat.
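A hedged sketch of what such a check could look like if it did stay on the model side; the helper name is made up, and device_lib is the same private import shown in the hunk above:

from tensorflow.python.client import device_lib  # pylint: disable=g-bad-import-order

def validate_num_gpus(requested_gpus):
  # Fail fast if fewer GPUs are present than were requested.
  local_device_protos = device_lib.list_local_devices()
  available = sum(1 for d in local_device_protos if d.device_type == 'GPU')
  if requested_gpus > available:
    raise ValueError('Requested {} GPUs but only {} are available.'.format(
        requested_gpus, available))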


Args:
add_help: Create the "--help" flag. False if class instance is a parent.
batch_size: Create a flag to specify the batch size. (Instead of the one
Contributor

what is this batch size here?

Contributor Author

Global. I will make that more clear.

)

self.add_argument(
"--gpus_for_distribution_strategy", "-gds",
Contributor

wonder if we should just rename gpus_for_distribution_strategy -> num_gpus everywhere?

Contributor

Yes please. And I think we can also plausibly use num_gpus=0, num_gpus=1 to represent the fact that we want diststrat even in single-device cases. That would allow us to ditch the second arg, and roll this back into the main parser, where we already have multi_gpu.

Contributor

+1

Contributor Author

Done.
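For illustration, a rough sketch of the num_gpus-to-strategy mapping proposed in this thread, using the top-level contrib aliases mentioned earlier; the function itself is hypothetical, not the code that landed:

import tensorflow as tf

def strategy_from_num_gpus(num_gpus):
  # 0 -> CPU only, 1 -> a single GPU, >1 -> mirrored across GPUs.
  if num_gpus == 0:
    return tf.contrib.distribute.OneDeviceStrategy('/device:CPU:0')
  if num_gpus == 1:
    return tf.contrib.distribute.OneDeviceStrategy('/device:GPU:0')
  return tf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus)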

@@ -104,20 +104,18 @@ def preprocess_image(image, is_training):


def input_fn(is_training, data_dir, batch_size, num_epochs=1,
num_parallel_calls=1, multi_gpu=False):
use_distribution_strategy=False,
gpus_for_distribution_strategy=1):
Contributor

I feel like we should generalize, rather than tie ourselves to the particular name of how we are distributing. Can we also reduce this to num_gpus, and then use dist strat if num_gpus > 1? That won't extend to future implementations, but we'll have to change this for the future implementations anyways.

Contributor

(Ditto throughout.)

Contributor Author

I think that's reasonable. I made the default 1 if tf.test.is_built_with_cuda() else 0. My justification is that if you don't specify --num_gpus and you have a GPU, people generally expect TensorFlow to do work there.

Contributor

+1
Think of the params etc that we had as something done in crunch time :) so it would be great to change it to whatever makes most sense for a user.


def test_resnet_model_fn_train_mode_multi_gpu_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True)
Contributor

Removing tests without replacement? That seems unlike you, @robieta .

Contributor Author

The thing is, model_fn used to be device-aware because of replicate_model_fn, and now it isn't. That's why I can safely remove that test. And we don't currently have the infrastructure for me to set up a multi-GPU end-to-end test.


from official.resnet import resnet_model
from official.utils.arg_parsers import parsers
from official.utils.export import export
from official.utils.logs import hooks_helper
from official.utils.logs import logger
# pylint: enable=g-bad-import-order
Contributor

Once we move to top-level imports as @guptapriya notes above, we can just switch back to the single-line import order.

Contributor Author

Indeed.


Returns:
Dataset of (image, label) pairs ready for iteration.
"""

# TODO(taylorrobie@) remove when DistributionStrategies uses global batch size
Contributor

Let's not leave TODOs in the public code; a comment explaining that this is only necessary for a short period because etc etc is sufficient. Also, nit, for future reference: TODO(taylorrobie), and, this is public code, so, robieta.

Contributor Author

Duly noted.

total_examples = num_epochs * examples_per_epoch
dataset = dataset.take(batch_size * (total_examples // batch_size))
dataset = dataset.take(
per_device_batch_size * (total_examples // per_device_batch_size))
Contributor

Is this still necessary? I would imagine not. This was originally a fix for the fact that replicate_model_fn would error out, as noted in the comment (which should also be updated in the unexpected case that this is still relevant). It would be great to remove this, because then we can avoid passing multi-GPU knowledge this far in, compute the batch size closer to the top of the processing, and pass the desired batch_size here without this function caring whether it's global or per-device.

Contributor Author

Yes. Testing confirms that we can get rid of this entire section.

'version': flags.version,
})

if flags.benchmark_log_dir is not None:
benchmark_logger = logger.BenchmarkLogger(flags.benchmark_log_dir)
benchmark_logger.log_run_info("resnet")
benchmark_logger.log_run_info('resnet')
Contributor

I swear I've changed this in about 5 branches now.

Contributor Author

No kidding.

@@ -118,6 +118,7 @@ def __init__(self, add_help=False, data_dir=True, model_dir=True,
metavar="<BS>"
)

# TODO(taylorrobie@): deprecate and only use DistributionStrategies
Contributor

Ditto on Todos.

@@ -151,6 +152,7 @@ def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True, max_train_steps=True):
super(PerformanceParser, self).__init__(add_help=add_help)

# TODO(taylorrobie@): deprecate and only use DistributionStrategies
Contributor

Right now, only MNIST uses this, correct? Can you follow up by making that use DistStrat as well? And, in theory, WideDeep should be easy, because it's just Estimators... is that true? Can you check and update as well if so?


"--gpus_for_distribution_strategy", "-gds",
type=int, default=2,
help="[default: %(default)s] How many GPUs to use with the "
"DistributionStrategies API.",
Contributor

We need links to the docs somewhere, and this seems like a good place. Maybe also in the comment noting that multi-GPU is experimental.

Contributor

@karmel commented Apr 6, 2018

I just checked with a SavedModel exported from this PR, and it does indeed appear that this fixes the problem with replicate_model_fn where you had to export the SavedModel on a single GPU even if training on multiple GPUs. CC @isaprykin and @k-w-w, who are interested parties. @robieta -- can you remove warn_on_multi_gpu as well?

Contributor

@guptapriya left a comment

Thanks for the thorough review, @karmel! We should definitely make this more usable than our version, which was done in a hurry. I left responses to some of your questions.


'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
).format(num_gpus, batch_size, batch_size - remainder)
raise ValueError(err)
if use_distribution_strategy and gpus_for_distribution_strategy > 1:
Contributor

+1
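A minimal sketch of the per-device split implied by the error message in the hunk above; the helper name is hypothetical:

def per_device_batch_size(batch_size, num_gpus):
  # With more than one GPU, the global batch must divide evenly across
  # devices; otherwise raise the same kind of error shown above.
  if num_gpus <= 1:
    return batch_size
  remainder = batch_size % num_gpus
  if remainder:
    raise ValueError(
        'Found {} GPUs with a batch size of {}; try --batch_size={} '
        'instead.'.format(num_gpus, batch_size, batch_size - remainder))
  return batch_size // num_gpus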

@@ -355,21 +370,35 @@ def resnet_main(flags, model_function, input_function, shape=None):
allow_soft_placement=True)

# Set up a RunConfig to save checkpoint and set session config.
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9,
session_config=session_config)
if not flags.use_distribution_strategy:
Contributor

I think that's a good idea. When we originally added these flags, we were not removing the previous multi-GPU approach, so we didn't want to enable it by default. But I think now it makes sense. OneDeviceStrategy should pretty much do what having no distribution strategy does.

distribution = mirrored_strategy.MirroredStrategy(
num_gpus=flags.gpus_for_distribution_strategy
)
run_config = tf.estimator.RunConfig(distribute=distribution).replace(
Contributor

Ah yes, sorry -- we just changed it after a discussion. You're right, it's called train_distribute now. Thanks for catching this!
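A small sketch of wiring the strategy into the RunConfig with the renamed keyword, assuming the top-level contrib alias; the specific values are illustrative only:

import tensorflow as tf

session_config = tf.ConfigProto(allow_soft_placement=True)
distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
# The RunConfig keyword is train_distribute, as noted above.
run_config = tf.estimator.RunConfig(train_distribute=distribution).replace(
    save_checkpoints_secs=1e9, session_config=session_config)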

num_gpus=flags.gpus_for_distribution_strategy
)
run_config = tf.estimator.RunConfig(distribute=distribution).replace(
save_checkpoints_secs=1e9, session_config=session_config)
Contributor

Hm, I am not 100% sure. We did test setting those as env variables during performance tuning but did not find significant benefits. But I don't believe we changed anything in distribution strategy itself to explicitly support or not support them. Perhaps @isaprykin can shed more light?


Contributor Author

@robieta left a comment

I've started addressing comments. I will do another round shortly.


def input_fn(is_training, data_dir, batch_size, # pylint: disable=unused-argument, missing-docstring
use_distribution_strategy=False,
gpus_for_distribution_strategy=1, *args, **kwargs): # pylint: disable=unused-argument
# TODO(taylorrobie@) cull DistributionStrategies uses global batch size
Contributor Author

Yeah, once the input_fn is device blind there's a whole lot of variable passing that can be ripped out. So satisfying.

@@ -355,21 +370,35 @@ def resnet_main(flags, model_function, input_function, shape=None):
allow_soft_placement=True)

# Set up a RunConfig to save checkpoint and set session config.
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9,
session_config=session_config)
if not flags.use_distribution_strategy:
Contributor Author

Done.

session_config=session_config)
if not flags.use_distribution_strategy:
run_config = tf.estimator.RunConfig().replace(
save_checkpoints_secs=1e9, session_config=session_config)
Contributor Author

Fine by me.

distribution = mirrored_strategy.MirroredStrategy(
num_gpus=flags.gpus_for_distribution_strategy
)
run_config = tf.estimator.RunConfig(distribute=distribution).replace(
Contributor Author

Perils of working from contrib. ^^


@robieta force-pushed the resnet_distribution_strategies branch from a819c2c to 4a1207a on April 6, 2018
@robieta requested a review from a team as a code owner on April 6, 2018
@robieta force-pushed the resnet_distribution_strategies branch 2 times, most recently from 4bcd195 to 667a760 on April 9, 2018
@googlebot added the cla: no label and removed the cla: yes label on Apr 9, 2018
Contributor Author

@robieta commented Apr 9, 2018

It appears that some git tomfoolery has occurred. I will sort it out, and apologies to those of you who got dragged in as owners.

@robieta force-pushed the resnet_distribution_strategies branch from 667a760 to 11192e9 on April 9, 2018
@tensorflow deleted a comment from googlebot on Apr 9, 2018
@karmel added the cla: yes label and removed the cla: no label on Apr 9, 2018
Contributor Author

@robieta commented Apr 9, 2018

@karmel I think we're ready for you to take another pass.

@@ -99,14 +99,14 @@ class BaseParser(argparse.ArgumentParser):
model_dir: Create a flag for specifying the model file directory.
train_epochs: Create a flag to specify the number of training epochs.
epochs_between_evals: Create a flag to specify the frequency of testing.
batch_size: Create a flag to specify the batch size.
batch_size: Create a flag to specify the global batch size.
multi_gpu: Create a flag to allow the use of all available GPUs.
Member

Missing doc for num_gpu

# critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
# allows DistributionStrategies to adjust how many batches to fetch based
# on how many devices are present.
dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

@@ -122,7 +106,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
An input_fn that can be used in place of a real one to return a dataset
that can be used for iteration.
"""
def input_fn(is_training, data_dir, batch_size, *args): # pylint: disable=unused-argument
def input_fn(is_training, data_dir, batch_size, *args, **kwargs): # pylint: disable=unused-argument,missing-docstring
Contributor

This function should be short enough to not require a docstring (< 10 lines); are you sure this extra disable is necessary?

Contributor Author

It was erroring on the Kokoro lint, but looks like it isn't now. It will forever remain a mystery.

@@ -355,21 +370,35 @@ def resnet_main(flags, model_function, input_function, shape=None):
allow_soft_placement=True)
Contributor

Tagging @isaprykin on this one too-- do we need allow_soft_placement still?

multi_gpu: Create a flag to allow the use of all available GPUs.
hooks: Create a flag to specify hooks for logging.
"""

def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_between_evals=True, batch_size=True,
multi_gpu=True, hooks=True):
multi_gpu=True, num_gpu=True, hooks=True):
Contributor

I know it will be ripped out in a coming PR, but, for now, let's set the default for multi_gpu=False

if multi_gpu:
self.add_argument(
"--multi_gpu", action="store_true",
help="If set, run across all available GPUs."
)

if num_gpu:
self.add_argument(
"--num_gpus", "-ng",
Contributor

I feel like we are reaching some upper bound of abbreviation strings, but, then again, I am the type that always prefers the fully explicit versions.

Contributor Author

U wil pry my CLI abr frm my cld, ded hnds.

self.add_argument(
"--num_gpus", "-ng",
type=int,
default=1 if tf.test.is_built_with_cuda() else 0,
Contributor

Can you add a test to make sure this default gets set correctly in the arg parsers test? I think there's a way to force one mode or the other... if not, have you at least confirmed that this does work correctly on GPU/CPU?

Contributor Author

I tested manually and confirmed. Probably not worth the effort to make a formal test simply because at that point it's more of a test of tf.test than the model garden.
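If a formal test ever seems worthwhile, here is a self-contained sketch of forcing each mode with a mock. It uses a toy parser that mirrors the default logic under discussion rather than the real parser in official/utils, and assumes Python 3's unittest.mock:

import argparse
import unittest
from unittest import mock

import tensorflow as tf


def build_parser():
  # Toy stand-in that mirrors the "--num_gpus" default logic.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--num_gpus", "-ng", type=int,
      default=1 if tf.test.is_built_with_cuda() else 0)
  return parser


class NumGpusDefaultTest(unittest.TestCase):

  def test_cpu_default(self):
    # Force the CPU-only build path regardless of how TF was actually built.
    with mock.patch.object(tf.test, "is_built_with_cuda", return_value=False):
      self.assertEqual(build_parser().parse_args([]).num_gpus, 0)

  def test_gpu_default(self):
    with mock.patch.object(tf.test, "is_built_with_cuda", return_value=True):
      self.assertEqual(build_parser().parse_args([]).num_gpus, 1)


if __name__ == "__main__":
  unittest.main()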

type=int,
default=1 if tf.test.is_built_with_cuda() else 0,
help="[default: %(default)s] How many GPUs to use with the "
"DistributionStrategies API.",
Contributor

Can you add a note here that reflects the details in the readme? Specifically, that 0==CPU, 1==GPU, default is what you built TF with.

Taylor Robie added 5 commits April 12, 2018 12:53
more changes to resnet_run_loop

use AUTOTUNE in prefetch

first pass at resnet with functional distribution strategies

fix syntax error

delint

aesthetic tweaks

delint and fix typos

rip multi_gpu flag out of resnet entirely. Subject to saved model load verification

update cifar10 and imagenet tests to reflect that the model function no longer need to know about multi_gpu

fix imagenet test

start addressing PR comments

more PR response work
@robieta force-pushed the resnet_distribution_strategies branch from bebd187 to 2f41c72 on April 12, 2018
@robieta merged commit 32aa656 into master on Apr 12, 2018
@robieta deleted the resnet_distribution_strategies branch on April 12, 2018
robieta pushed a commit that referenced this pull request Apr 19, 2018
robieta pushed a commit that referenced this pull request Apr 19, 2018
omegafragger pushed a commit to omegafragger/models that referenced this pull request May 15, 2018
* begin transfer from contrib fork

more changes to resnet_run_loop

use AUTOTUNE in prefetch

first pass at resnet with functional distribution strategies

fix syntax error

delint

aesthetic tweaks

delint and fix typos

rip multi_gpu flag out of resnet entirely. Subject to saved model load verification

update cifar10 and imagenet tests to reflect that the model function no longer need to know about multi_gpu

fix imagenet test

start addressing PR comments

more PR response work

* misc tweaks

* add a comment

* final pr tweaks

* fix parsers
omegafragger pushed a commit to omegafragger/models that referenced this pull request May 15, 2018