-
Notifications
You must be signed in to change notification settings - Fork 45.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
official/mnist: Use tf.keras.Sequential #3942
official/mnist: Use tf.keras.Sequential #3942
Conversation
official/mnist/mnist.py
Outdated
Returns: | ||
A tf.keras.Model. | ||
""" | ||
input_shape = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably can skip this here since it will be declared by the if-else below anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
official/mnist/mnist.py
Outdated
|
||
L = tf.keras.layers | ||
max_pool = L.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format) | ||
return tf.keras.Sequential([ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is much more readable than the previous version. Thanks.
Please also fix the lint error. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Linted :). Please take another look.
official/mnist/mnist.py
Outdated
5, | ||
padding='same', | ||
data_format=data_format, | ||
activation=tf.nn.relu), max_pool, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: max_pool gets lost at the end of the other function call. Can you move to its own line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is | ||
typically faster on GPUs while 'channels_last' is typically faster on | ||
CPUs. See | ||
https://www.tensorflow.org/performance/performance_guide#data_formats |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is neat. Anticipating that people are going to ask how this is different than the Resnet/call style, and why we use tf.keras.layers here, can you add comments explaining what Sequential does? And, for our reference-- should we be using tf.keras.layers at this point? Sequential in the cases where it makes sense?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sequential
is a sub-class of Model
so this isn't really different, it's just simpler. If your network is really just a chain of layers, Sequential
saves some boilerplate.
I'm a bit hesitant to add a comment here explaining that since that comment would either have to pre-suppose that model-subclassing is the most typical usage (which it isn't), or first explain that there is another more-verbose way to write this network and then explain how Sequential
is simpler :). But perhaps I'm overthinking, so if you have a suggestion of what the comment should be - happy to add it.
Yes, in other samples we should use Sequential
wherever it fits, and use tf.keras.layers
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"A subclass of tf.keras.Model, tf.keras.Sequential, is returned here, as the MNIST model is just a chain of layers." Or something to that effect.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
input_shape = [1, 28, 28] | ||
else: | ||
assert data_format == 'channels_last' | ||
input_shape = [28, 28, 1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be worth commenting that data_format is different between tf.layers and tf.keras.layers. I would have missed that and been confused if trying this+other models.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, it isn't different between tf.layers
and tf.keras.layers
. It's just that the Reshape
layer's target_shape
argument excludes the batch dimension.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I see. SG.
No description provided.