
Add fp16 support to official ResNet. #3687

Merged
robieta merged 7 commits into master from float16_resnet on Apr 9, 2018

Conversation

robieta
Contributor

@robieta robieta commented Mar 21, 2018

This PR adds fp16 support to official/resnet. The vast majority of the work was already done by Reed a month ago (including wonderful comments), and this just rolls those changes into the current master.

Currently training is I/O bound; however, synthetic data runs confirm the fp16 acceleration to ~4000 images/sec during training.

I will update when I have run results.

@robieta robieta requested review from karmel, nealwu, reedwm and tfboyd March 21, 2018 20:25
@robieta robieta requested a review from k-w-w as a code owner March 21, 2018 20:25
@karmel
Contributor

karmel commented Mar 22, 2018

Quick thought before a fuller review: there is a future in which we support a number of different quantization options, at least for some models. Is it feasible to generalize this to a dtype being passed around rather than fixing it to only fp16? You don't need to handle multiple dtypes for now, but at least don't fix all the params and names to "fp16".

@tfboyd
Member

tfboyd commented Mar 22, 2018 via email

@robieta
Contributor Author

robieta commented Mar 23, 2018

My understanding is that once automatic fp16 scaling is ready, all of this will be ripped out, and at that point passing dtype will be natural. If you're worried about the interface changing, it's no trouble to have dtype as the CLI flag. The only thing that might be weird is loss_scaling, as that would still have to be an fp16-specific CLI flag.

@karmel
Contributor

karmel commented Mar 23, 2018

fp16 today, int8 tomorrow. Let's generalize to dtype. Then allow for a loss_scale to be passed in for the selected dtype. If not passed in, select from a dict of defaults based on the selected dtype? Open to pros and cons on that, but I think we should anticipate more dtypes rather than fixing on the one currently ready.
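
For concreteness, a minimal sketch of the kind of dtype-to-defaults mapping being proposed; the names (DTYPE_MAP, resolve_dtype) and the int8 placeholder are illustrative, not part of this PR:

```python
import tensorflow as tf

# Hypothetical module-level mapping from CLI dtype strings to
# (tf dtype, default loss scale). New dtypes would only need to be added here.
DTYPE_MAP = {
    "fp16": (tf.float16, 128),
    "fp32": (tf.float32, 1),
    # "int8": (tf.int8, 1),  # a future quantization option could slot in here
}


def resolve_dtype(dtype_str, loss_scale=None):
  """Look up the tf dtype, falling back to a per-dtype default loss scale."""
  dtype, default_loss_scale = DTYPE_MAP[dtype_str]
  return dtype, loss_scale or default_loss_scale


# Illustrative usage with argparse-style flags:
# dtype, loss_scale = resolve_dtype(flags.dtype, flags.loss_scale)
```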

@tfboyd
Member

tfboyd commented Mar 23, 2018 via email

@robieta
Contributor Author

robieta commented Apr 2, 2018

The interface now uses dtype and loss_scale instead of fp16 and fp16_loss_scale.

I did a training run on the V100s and got 75.87%. I also did 3 runs with master (I had to use a smaller batch size due to OOMs on P100s) and got a mean of 75.75% and a std of 0.28%. So the use of fp16 does not seem to affect accuracy.

Also @karmel, this PR conflicts with your checkpoint PR because it changes the namespace with the custom getter, so we will need to coordinate those two PRs.

Contributor

@karmel karmel left a comment


Some superb commenting in this PR-- much appreciated, thank you.

return resnet_run_loop.resnet_model_fn(
dtype=params['dtype'],
features=features, labels=labels, mode=mode, model_class=Cifar10Model,
resnet_size=params['resnet_size'],
Contributor


nit: line separations here are inconsistent (ln2 versus the rest)

Contributor Author


Done.

loss_filter_fn=loss_filter_fn,
multi_gpu=params['multi_gpu'])
return resnet_run_loop.resnet_model_fn(
dtype=params['dtype'],
Contributor


nit: TF convention would suggest dtype should be the last kwarg for a function.

Contributor Author


Done.

@@ -351,7 +351,8 @@ def __init__(self, resnet_size, bottleneck, num_classes, num_filters,
kernel_size,
conv_stride, first_pool_size, first_pool_stride,
second_pool_size, second_pool_stride, block_sizes, block_strides,
final_size, version=DEFAULT_VERSION, data_format=None):
final_size, version=DEFAULT_VERSION, data_format=None,
dtype=None):
Contributor


For this and the child classes above: you should be able to make the default tf.float32, rather than None.

Contributor Author


Done. I added a global default and use that.

@@ -418,6 +421,60 @@ def __init__(self, resnet_size, bottleneck, num_classes, num_filters,
self.block_sizes = block_sizes
self.block_strides = block_strides
self.final_size = final_size
self.dtype = dtype or tf.float32
Contributor


Not suggesting we should, but, to consider: do some type-checking on this? If someone passes in np.float32, does this work? If an invalid type is passed in, the resulting error will be cryptic. Maybe we should have an ALLOWED_TYPES somewhere, and validate those?

Contributor Author


Actually, yes. It turns out that tensorflow will coerce numpy dtypes into tf dtypes, which would result in very subtle issues. (I'm not even sure it would hard fail.) This seems worthwhile.
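
As a reference point, a minimal sketch of that kind of allow-list validation, using the constant names that show up later in this diff (DEFAULT_DTYPE, CASTABLE_TYPES, ALLOWED_TYPES); the helper name is hypothetical:

```python
import tensorflow as tf

DEFAULT_DTYPE = tf.float32
CASTABLE_TYPES = (tf.float16,)
ALLOWED_TYPES = (DEFAULT_DTYPE,) + CASTABLE_TYPES


def _validate_dtype(dtype):
  # Check against an explicit allow-list so an unsupported dtype fails loudly
  # here with a clear message, instead of surfacing as a cryptic error deeper
  # in graph construction.
  if dtype not in ALLOWED_TYPES:
    raise ValueError("dtype must be one of {}; got {}".format(
        ALLOWED_TYPES, dtype))
```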

*args, **kwargs):
"""Creates variables in fp32, then casts to fp16 if necessary.

This function is a custom getter. A custom getter is a function with the
Contributor


nit: indentation not necessary

Contributor Author


Every time. One day I'll learn...
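
For readers following along, a minimal sketch of the pattern the docstring above describes (master variables created in fp32, cast to the low-precision dtype on read); the function name and usage are illustrative, not the exact code in the PR:

```python
import tensorflow as tf


def fp32_variable_getter(getter, name, shape=None, dtype=tf.float16,
                         *args, **kwargs):
  """Custom getter: create variables in fp32, then cast to the requested dtype.

  Keeping the master weights in fp32 avoids accumulating updates in fp16,
  which loses precision; only the value read by the graph is cast down.
  """
  if dtype in (tf.float16,):
    var = getter(name, shape, tf.float32, *args, **kwargs)
    return tf.cast(var, dtype=dtype, name=name + "_cast")
  return getter(name, shape, dtype, *args, **kwargs)


# Illustrative usage: build the model under a scope that applies the getter.
# with tf.variable_scope("resnet_model", custom_getter=fp32_variable_getter):
#   logits = model(inputs, training=True)
```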

# so small, they underflow to 0. To avoid this, we multiply the loss by
# loss_scale to make these tensor values loss_scales times bigger.
scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)
unscaled_grad_vars = [(grad / loss_scale, var)
Contributor


Can you comment to explain the second step here too?

Contributor Author


Done.
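
For the record, a commented sketch of the two-step loss-scaling pattern under discussion; names mirror the diff above, the optimizer and loss values are placeholders for illustration, and the optimizer calls are the standard tf.train ones:

```python
import tensorflow as tf

# Placeholder loss and scale for illustration; in the PR these come from the
# model function and the --loss_scale flag.
loss = tf.reduce_mean(tf.square(tf.get_variable("w", [10], dtype=tf.float32)))
loss_scale = 128

optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
global_step = tf.train.get_or_create_global_step()

# Step 1: scale the loss up so that small fp16 gradient values do not
# underflow to zero during backpropagation.
scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)

# Step 2: every gradient is now too large by a factor of loss_scale, so divide
# it back out before applying. The resulting weight updates match what
# unscaled fp32 training would have produced.
unscaled_grad_vars = [(grad / loss_scale, var)
                      for grad, var in scaled_grad_vars]
minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
```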



def parse_dtype_info(flags):
"""Convert dtype string to tf dtype, and set loss_scale default as needed.
Contributor


Tests, please.

Contributor Author


Added.

self.add_argument(
"--dtype", "-dt",
default="fp32",
choices=["fp16", "float16", "fp32", "float32"],
Contributor


This is nice, but also 2x the maintenance. Should we just enforce one nomenclature or the other?

Contributor Author


I think it's fine simply because the first thing we do is convert them out of strings. But then again I'd write a 1000 page opus in defense of shaving off a single character from a CLI arg. I'm going to leave it for now, but if you decide to decree I won't fight it.

"but the loss scale helps avoid some intermediate gradients "
"from underflowing to zero. If not provided the default for "
"fp16 is 128 and 1 for all other dtypes.",
)
Contributor


Happy face.

inputs = tf.identity(inputs, 'final_dense')

return inputs
with self._model_variable_scope():
Contributor


Aw, snap. Can you generate new checkpoints and SavedModels for this?

Contributor Author


Sure.

@robieta
Contributor Author

robieta commented Apr 3, 2018

@karmel I'll address comments specifically shortly, but two points of note:

  1. tf.cast is sometimes a literal no-op. If layer is a tf.Tensor, then layer is tf.cast(layer, tf.float32) is True. However, if it is a tf.SparseTensor of float32s, then layer is tf.cast(layer, tf.float32) is False. The reason seems to be that SparseTensors allow heterogeneous dtypes. So we could take the tf.cast calls out of conditionals if we were sure we won't use SparseTensors, but that assumption seems dubious since there are categorical variables. (We could also request that tf.cast change its behavior when ALL elements of a SparseTensor are already the requested dtype.) See the snippet after this list.
  2. Rolling a lot of the dtype logic into a util would be cleaner and more generalizable. However, I think it's better if that sort of code is provided by TensorFlow. So we need to determine whether this is just a short-term stopgap so we can do some V100 testing, or whether the automated mixed precision code is far enough out to warrant a formal utility in the meantime.
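
For reference, the snippet mentioned in point 1, illustrating the identity behavior described there (as observed on the TF 1.x API this PR targets):

```python
import tensorflow as tf

# Dense tensor: casting to the dtype it already has returns the same object.
dense = tf.constant([1.0, 2.0], dtype=tf.float32)
print(dense is tf.cast(dense, tf.float32))    # True: a literal no-op

# SparseTensor: even when the values are already float32, tf.cast rebuilds the
# SparseTensor, so the identity check fails.
sparse = tf.SparseTensor(indices=[[0]], values=[1.0], dense_shape=[2])
print(sparse is tf.cast(sparse, tf.float32))  # False: a new SparseTensor
```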

@robieta
Contributor Author

robieta commented Apr 3, 2018

Discussed offline.

  1. Just unconditionally cast, and note in a comment that for a SparseTensor of fp32s it may still not be a no-op.
  2. Leave the mixed precision code in official/resnet in the hope that we can soon rip it out and replace it with mixed precision management from tf proper. If we have a second model where we want fp16 and the tf proper version isn't ready, we can reevaluate moving that code to official/utils.

@robieta robieta force-pushed the float16_resnet branch 2 times, most recently from 89f164e to 44f53ff on April 4, 2018 17:10
@robieta
Contributor Author

robieta commented Apr 4, 2018

@karmel I have addressed your comments.

Contributor

@karmel karmel left a comment


Almost there.

@@ -36,6 +36,9 @@
_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5
DEFAULT_VERSION = 2
DEFAULT_DTYPE = tf.float32
CASTABLE_TYPES = (tf.float16,)
ALLOWED_TYPES = (tf.float32,) + CASTABLE_TYPES
Contributor


nit: (DEFAULT_DTYPE, ) +...

been called if no custom getter was used. Custom getters typically get a
variable with `getter`, then modify it in some way.

This custom getter will create an fp32 variable. If an low precision
Contributor


nit: a low precision

"float16": tf.float16,
"fp32": tf.float32,
"float32": tf.float32,
}.get(flags.dtype, flags.dtype)
Contributor


On second read, I think we should stick with just two options-- fp32, fp16.

Also, let's move this to a module-level dict; then you can just get DTYPE_MAP.keys() for choices below.

}.get(flags.dtype, flags.dtype)

if flags.dtype is None or isinstance(flags.dtype, str):
raise ValueError("Invalid dtype: {}".format(flags.dtype))
Contributor


Why not just try to do DTYPE_MAP[flags.dtype] and catch the KeyError? Fewer lines, no isinstance check.

flags.loss_scale = {
"float16": 128,
"float32": 1,
}[flags.dtype.name]
Contributor


Also a constant? Perhaps in the same DTYPE_MAP. Ideally keyed on the same args, rather than a different string drawn from the TF name, which is not in our control.

self.add_argument(
"--dtype", "-dt",
default="fp32",
choices=["fp16", "float16", "fp32", "float32", "int8"],
Contributor


int8 is not actually a choice currently, and will fail with a ValueError above, correct? In any case, a call to .keys() on a constant will keep us in sync.

args = parser.parse_args(["--dtype", dtype_str, "--loss_scale", "5"])
parsers.parse_dtype_info(args)

assert args.loss_scale == 5
Contributor


Test that int8/invalid types raise errors, given the discrepancy currently in choices versus this function.

@robieta
Contributor Author

robieta commented Apr 4, 2018

I made the changes. I was able to refactor parse_dtype_info() so that it is still idempotent, but much cleaner, with everything keyed off of DTYPE_MAP.
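
A sketch of the refactored helper as described (idempotent and keyed entirely off DTYPE_MAP); it mirrors the shape of the discussion above rather than quoting the final file:

```python
import tensorflow as tf

DTYPE_MAP = {
    "fp16": (tf.float16, 128),
    "fp32": (tf.float32, 1),
}


def parse_dtype_info(flags):
  """Convert flags.dtype from a string to a tf dtype and fill in loss_scale."""
  if flags.dtype in (i[0] for i in DTYPE_MAP.values()):
    return  # Already parsed; calling this a second time is a no-op.

  try:
    flags.dtype, default_loss_scale = DTYPE_MAP[flags.dtype]
  except KeyError:
    raise ValueError("Invalid dtype: {}".format(flags.dtype))

  # Respect an explicit --loss_scale, otherwise use the per-dtype default.
  flags.loss_scale = flags.loss_scale or default_loss_scale
```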

Contributor

@karmel karmel left a comment


Another nit or two, but looks good. Can you also post results/tensorboards for the record?

ValueError: If an invalid dtype is provided.
"""
if not (flags.dtype is None or isinstance(flags.dtype, str)):
return # Make function idempotent
Contributor


Shouldn't this instead be if dtype is in the set of allowed tf dtypes? Otherwise, you could pass in an int and get through, right? Although maybe that's being too defensive if we assume this only gets called with flags generated by the argparser, which requires a string.

Contributor Author


I think I was being overly defensive because I was afraid of odd behavior in "in" or dict keying. But I now use "in (tf.float16, tf.float32)" elsewhere, so I suppose doing the same here doesn't introduce any additional risk.


flags.loss_scale = (flags.loss_scale if flags.loss_scale else
default_loss_scale)

Contributor


nit: this can just be = flags.loss_scale or default_loss_scale

Contributor Author


Right.

@robieta
Contributor Author

robieta commented Apr 4, 2018

Cool. I'll hold off merging until I have checkpoints to go along.

@robieta robieta requested a review from a team as a code owner April 9, 2018 16:36
@robieta robieta merged commit fbb27cf into master Apr 9, 2018
@robieta robieta deleted the float16_resnet branch April 9, 2018 20:08
@jonasrauber
Contributor

@robieta Great work! I have a couple of questions regarding the performance:

Currently training is I/O bound; however, synthetic data runs confirm the fp16 acceleration to ~4000 images/sec during training.

Currently? Is there any hope that this will not be I/O-bound in the near future?

That’s 4000 images/sec on ImageNet, right? Which GPU? 4000 compared to fp32 with the same setup and on the same GPU?

@robieta
Contributor Author

robieta commented Apr 10, 2018

@jonasrauber Hi. #3887 significantly improves performance. There's still work to be done, but the model will be much faster (and no longer I/O bound as far as I can tell) once that gets merged. The 4k/sec figure is for 8 Nvidia V100s; fp32 is, unsurprisingly, exactly half of fp16. Again, I should emphasize that these are very ad hoc measurements, so don't read too much into them.

omegafragger pushed a commit to omegafragger/models that referenced this pull request May 15, 2018
* Add fp16 support to resnet.

* address PR comments

* add dtype checking to model definition

* delint

* more PR comments

* few more tweaks

* update resnet checkpoints