Add fp16 support to official ResNet. #3687
Conversation
Quick thought before a more full review: there is a future in which we support a number of different quantization options, at least for some models. Is it feasible to generalize this to a dtype being passed around rather than fixing to only fp16? You don't need to handle multiple dtypes for now, but at least don't fix all the params and names to "fp16".
Yes, that is how the other platforms handle it. I doubt we would handle int8, but you never know. Passing the dtype is popular.
My understanding is that once automatic fp16 scaling is ready, all of this will be ripped out, and at that point passing dtype will be natural. If you're worried about the interface changing, it's no trouble to have dtype as the CLI flag. The only thing that might be weird is loss_scaling, as that would still have to be an fp16 CLI flag.
fp16 today, int8 tomorrow. Let's generalize to dtype. Then allow for a loss_scale to be passed in for the selected dtype. If not passed in, select from a dict of defaults based on the selected dtype? Open to pros and cons on that, but I think we should anticipate more dtypes rather than fixing on the one currently ready.
+1 to karmel's idea. That matches or is better than what I have seen from other platforms.
Force-pushed from 0d5ec60 to f102f82.
The interface now uses dtype and loss_scale instead of fp16 and fp16_loss_scale. I did a training run on the V100s and got 75.87%. I also did 3 runs with master (I had to use a smaller batch size due to OOMs on P100s) and got a mean of 75.75% and a std of 0.28%, so the use of fp16 does not seem to affect accuracy. Also @karmel, this PR conflicts with your checkpoint PR because it changes the namespace with the custom getter, so we will need to coordinate those two PRs.
Some superb commenting in this PR-- much appreciated, thank you.
official/resnet/cifar10_main.py
Outdated
  return resnet_run_loop.resnet_model_fn(
      dtype=params['dtype'],
      features=features, labels=labels, mode=mode, model_class=Cifar10Model,
      resnet_size=params['resnet_size'],
nit: line separations here are inconsistent (ln2 versus the rest)
Done.
official/resnet/cifar10_main.py
Outdated
      loss_filter_fn=loss_filter_fn,
      multi_gpu=params['multi_gpu'])
  return resnet_run_loop.resnet_model_fn(
      dtype=params['dtype'],
nit: TF convention would suggest dtype should be the last kwarg for a function.
Done.
official/resnet/resnet_model.py
Outdated
@@ -351,7 +351,8 @@ def __init__(self, resnet_size, bottleneck, num_classes, num_filters,
               kernel_size,
               conv_stride, first_pool_size, first_pool_stride,
               second_pool_size, second_pool_stride, block_sizes, block_strides,
               final_size, version=DEFAULT_VERSION, data_format=None,
               dtype=None):
For this and the child classes above: you should be able to make the default tf.float32, rather than None.
Done. I added a global default and use that.
official/resnet/resnet_model.py
Outdated
@@ -418,6 +421,60 @@ def __init__(self, resnet_size, bottleneck, num_classes, num_filters,
    self.block_sizes = block_sizes
    self.block_strides = block_strides
    self.final_size = final_size
    self.dtype = dtype or tf.float32
Not suggesting we should, but, to consider: do some type-checking on this? If someone passes in np.float32, does this work? If an invalid type is passed in, the resulting error will be cryptic. Maybe we should have an ALLOWED_TYPES somewhere, and validate those?
Actually, yes. It turns out that TensorFlow will coerce numpy dtypes into tf dtypes, which would result in very subtle issues. (I'm not even sure it would hard fail.) This seems worthwhile.
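For illustration, a minimal sketch of the kind of validation being agreed on here, using the DEFAULT_DTYPE/CASTABLE_TYPES/ALLOWED_TYPES constants that appear later in the PR (exact wording in the model may differ):

  import tensorflow as tf

  DEFAULT_DTYPE = tf.float32
  CASTABLE_TYPES = (tf.float16,)
  ALLOWED_TYPES = (DEFAULT_DTYPE,) + CASTABLE_TYPES

  def _check_dtype(dtype):
    # Validate against an explicit whitelist of supported tf dtypes so an
    # unsupported type produces a clear error here rather than a cryptic
    # failure deep in the graph.
    if dtype not in ALLOWED_TYPES:
      raise ValueError('dtype must be one of {}; got {}'.format(
          ALLOWED_TYPES, dtype))
    return dtype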
official/resnet/resnet_model.py
Outdated
                           *args, **kwargs):
    """Creates variables in fp32, then casts to fp16 if necessary.

    This function is a custom getter. A custom getter is a function with the
nit: indentation not necessary
Every time. One day I'll learn...
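For context, a rough sketch of the custom-getter pattern this docstring describes: variables are created in fp32 for numerical stability, and a low-precision cast is handed back when a castable dtype is requested. The names and defaults here are illustrative, not the PR's exact code.

  import tensorflow as tf

  def _fp32_variable_getter(getter, name, shape=None, dtype=tf.float32,
                            *args, **kwargs):
    """Creates variables in fp32, then casts to the requested dtype if needed."""
    if dtype in (tf.float16,):
      # The master copy of the variable lives in fp32...
      var = getter(name, shape, tf.float32, *args, **kwargs)
      # ...but the layer that asked for it gets a low-precision cast.
      return tf.cast(var, dtype=dtype, name=name + '_cast')
    return getter(name, shape, dtype, *args, **kwargs)

  # Typical usage: attach the getter to the model's variable scope, e.g.
  # with tf.variable_scope('resnet_model', custom_getter=_fp32_variable_getter):
  #   ...build the network...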
official/resnet/resnet_run_loop.py
Outdated
    # so small, they underflow to 0. To avoid this, we multiply the loss by
    # loss_scale to make these tensor values loss_scales times bigger.
    scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)
    unscaled_grad_vars = [(grad / loss_scale, var)
Can you comment to explain the second step here too?
Done.
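For readers following along, the two steps being commented on amount to roughly this (a sketch assuming optimizer, loss, loss_scale, and global_step already exist):

  # Step 1: scale the loss up so that small fp16 gradient values do not
  # underflow to zero during backpropagation.
  scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)

  # Step 2: the resulting gradients are loss_scale times too large, so divide
  # each one back down before applying it; the variables are left untouched.
  unscaled_grad_vars = [(grad / loss_scale, var)
                        for grad, var in scaled_grad_vars]
  minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)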
def parse_dtype_info(flags):
  """Convert dtype string to tf dtype, and set loss_scale default as needed.
Tests, please.
Added.
  self.add_argument(
      "--dtype", "-dt",
      default="fp32",
      choices=["fp16", "float16", "fp32", "float32"],
This is nice, but also 2x the maintenance. Should we just enforce one nomenclature or the other?
I think it's fine simply because the first thing we do is convert them out of strings. But then again I'd write a 1000 page opus in defense of shaving off a single character from a CLI arg. I'm going to leave it for now, but if you decide to decree I won't fight it.
"but the loss scale helps avoid some intermediate gradients " | ||
"from underflowing to zero. If not provided the default for " | ||
"fp16 is 128 and 1 for all other dtypes.", | ||
) |
Happy face.
      inputs = tf.identity(inputs, 'final_dense')

      return inputs
    with self._model_variable_scope():
Aw, snap. Can you generate new checkpoints and SavedModels for this?
Sure.
@karmel I'll address comments specifically shortly, but two points of note:
Discussed offline.
Force-pushed from 89f164e to 44f53ff.
@karmel I have addressed your comments.
Almost there.
official/resnet/resnet_model.py
Outdated
@@ -36,6 +36,9 @@
_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5
DEFAULT_VERSION = 2
DEFAULT_DTYPE = tf.float32
CASTABLE_TYPES = (tf.float16,)
ALLOWED_TYPES = (tf.float32,) + CASTABLE_TYPES
nit: (DEFAULT_DTYPE, ) +...
official/resnet/resnet_model.py
Outdated
    been called if no custom getter was used. Custom getters typically get a
    variable with `getter`, then modify it in some way.

    This custom getter will create an fp32 variable. If an low precision
nit: a low precision
"float16": tf.float16, | ||
"fp32": tf.float32, | ||
"float32": tf.float32, | ||
}.get(flags.dtype, flags.dtype) |
On second read, I think we should stick with just two options-- fp32, fp16.
Also, let's move this to a module dict, then you can just get DTYPE_MAP.keys() for choices below.
  }.get(flags.dtype, flags.dtype)

  if flags.dtype is None or isinstance(flags.dtype, str):
    raise ValueError("Invalid dtype: {}".format(flags.dtype))
Why not just try to do DTYPE_MAP[flags.dtype] and catch the KeyError? Fewer lines, no isinstance check.
    flags.loss_scale = {
        "float16": 128,
        "float32": 1,
    }[flags.dtype.name]
Also a constant? Perhaps in the same DTYPE_MAP. Ideally keyed on the same args, rather than a different string drawn from the TF name, which is not in our control.
  self.add_argument(
      "--dtype", "-dt",
      default="fp32",
      choices=["fp16", "float16", "fp32", "float32", "int8"],
int8 is not actually a choice currently, and will fail with a ValueError above, correct? In any case, a call to .keys() on a constant will keep us in sync.
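Sketching the suggestion from the last few comments, assuming a module-level DTYPE_MAP keyed on the CLI strings, with each entry carrying the tf dtype and its default loss scale (values mirror the help text above; exact names in the merged code may differ):

  import argparse
  import tensorflow as tf

  DTYPE_MAP = {
      "fp16": (tf.float16, 128),
      "fp32": (tf.float32, 1),
  }

  # In the PR the flag lives on a parser class (hence self.add_argument above);
  # a plain ArgumentParser is used here to keep the sketch self-contained.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--dtype", "-dt",
      default="fp32",
      choices=list(DTYPE_MAP.keys()),  # CLI choices stay in sync with the map.
  )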
    args = parser.parse_args(["--dtype", dtype_str, "--loss_scale", "5"])
    parsers.parse_dtype_info(args)

    assert args.loss_scale == 5
Test that int8/invalid types raise errors, given the discrepancy currently in choices versus this function.
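A possible shape for that extra test case, reusing the parser/parsers names from the fragment above (the surrounding test harness and helper names are assumptions):

  def test_parse_dtype_info_rejects_invalid_dtype(self):
    # "int8" currently slips through the argparse choices but is not handled
    # by parse_dtype_info, so it should surface as a ValueError.
    args = parser.parse_args(["--dtype", "int8", "--loss_scale", "5"])
    with self.assertRaises(ValueError):
      parsers.parse_dtype_info(args)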
I made the changes. I was able to refactor parse_dtype_info() so it is still idempotent, but much cleaner and with everything keyed off of DTYPE_MAP.
Another nit or two, but looks good. Can you also post results/tensorboards for the record?
    ValueError: If an invalid dtype is provided.
  """
  if not (flags.dtype is None or isinstance(flags.dtype, str)):
    return  # Make function idempotent
Shouldn't this instead be if dtype is in the set of allowed tf dtypes? Otherwise, you could pass in an int and get through, right? Although maybe that's being too defensive if we assume this only gets called with flags generated by the argparser, which requires a string.
I think I was being overly defensive because I was afraid of odd behavior in "in" or dict keying. But I now use "in (tf.float16, tf.float32)" elsewhere, so I suppose doing the same here doesn't introduce any additional risk.
  flags.loss_scale = (flags.loss_scale if flags.loss_scale else
                      default_loss_scale)
nit: this can just be = flags.loss_scale or default_loss_scale
Right.
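Putting the fragments and nits above together, the refactored parse_dtype_info might look roughly like this (an illustrative sketch keyed off DTYPE_MAP; the merged code may differ in details):

  import tensorflow as tf

  DTYPE_MAP = {
      "fp16": (tf.float16, 128),
      "fp32": (tf.float32, 1),
  }

  def parse_dtype_info(flags):
    """Convert dtype string to tf dtype, and set loss_scale default as needed."""
    if flags.dtype in (i[0] for i in DTYPE_MAP.values()):
      return  # Flags have already been parsed; keeps the function idempotent.

    try:
      flags.dtype, default_loss_scale = DTYPE_MAP[flags.dtype]
    except KeyError:
      raise ValueError("Invalid dtype: {}".format(flags.dtype))

    # Respect an explicitly passed loss scale; otherwise fall back to the
    # default paired with the chosen dtype.
    flags.loss_scale = flags.loss_scale or default_loss_scale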
Cool. I'll hold off merging until I have checkpoints to go along.
@robieta Great work! I have a couple of questions regarding the performance:
Currently? Is there any hope that this will not be I/O-bound in the near future? That's 4000 images/sec on ImageNet, right? Which GPU? Is that 4000 compared to fp32 with the same setup and on the same GPU?
@jonasrauber Hi. #3887 significantly improves performance. There's still work to be done, but the model will be much faster (and no longer I/O bound so far as I can tell) once that gets merged. The 4k/sec figure is for 8 Nvidia V100s. fp32 is, unsurprisingly, exactly half of fp16. Again, I should emphasize that these are very ad-hoc measurements, so don't read too much into them.
* Add fp16 support to resnet.
* address PR comments
* add dtype checking to model definition
* delint
* more PR comments
* few more tweaks
* update resnet checkpoints
This PR adds fp16 support to official/resnet. The vast majority of the work was already done by Reed a month ago (including wonderful comments); this just rolls those changes into the current master.
Currently training is I/O bound; however, synthetic-data runs confirm the fp16 acceleration to ~4000 images/sec during training.
I will update when I have run results.