ENH: make flax_resnet.py faster and more accurate across the board #119
Conversation
LGTM, thanks a lot Fabian!
@mblondel: do you think I should switch maxiter -> epochs in the related examples for consistency? I'm thinking of flax_image_classif.py and perhaps the haiku* ones.
Should be doable, since I'm just computing the max_iter necessary to reach X epochs, not changing the structure of the for loop. As for id_print, I agree we should use it in the other examples.
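The epoch-to-iteration conversion described above can be sketched as follows; the dataset and batch sizes here are hypothetical placeholders, not the example's actual flags:

```python
# Convert an epoch budget into the max_iter count a solver expects.
# The concrete numbers below are illustrative only.
train_size = 50_000   # e.g. the CIFAR-10 training set
batch_size = 128
epochs = 30

steps_per_epoch = train_size // batch_size  # drop the last partial batch
max_iter = epochs * steps_per_epoch
print(max_iter)  # 11700
```

The loop structure is untouched: only the stopping bound changes, which is why the switch is cheap to replicate in the other examples.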
Most changes are geared towards ensuring we give strong and efficient baselines by default, without parameter tuning. The most dramatic gains are on the CIFAR* datasets.

Currently, the train/test accuracy for the different datasets with default flags is:
- MNIST: 0.99/0.99 (this one didn't change much).
- FASHION_MNIST: 0.94/0.88 with a time elapsed of 0:03:09 (vs 0.71/0.69 for the previous defaults).
- E_MNIST: 0.85/0.85 with a time elapsed of 0:25:56 (vs 0.31/0.37 for the previous defaults).
- CIFAR10: 0.87/0.75 with a time elapsed of 0:02:36 (vs 0.23/0.22 for the previous default architecture and 0.29/0.32 for resnet18).
- CIFAR100: 0.72/0.39 with a time elapsed of 0:02:31 (vs 0.07/0.06 for the previous default architecture and 0.11/0.09 for resnet18).

Here, time elapsed is measured on a workstation with a GeForce GTX 1080 GPU. It will certainly be slower on CPU, although I expect most people aiming to train large resnets to have access to a GPU.

More precisely, the main changes are:
* Refactored the print statements so that the update rule can be jit-compiled. It now runs very efficiently on GPU.
* Replaced maxiter with epochs. A maxiter of 100 is nowhere near enough to reach reasonable accuracy on CIFAR10, and I believe that setting the limit in terms of epochs is both more common and easier to choose when considering several datasets in the same example.
* Use resnet18 as the default. resnet1 reaches 0.4 test accuracy on CIFAR10 after 30 epochs, while resnet18 goes up to 0.7.
* Print the running time. It's been useful for me to compare the efficiency of different approaches.
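The first bullet (moving printing out of the update rule so it can be jit-compiled) can be sketched in plain JAX; none of the names below are from flax_resnet.py, and the model here is a toy least-squares objective rather than a resnet:

```python
import jax
import jax.numpy as jnp


@jax.jit  # compiled once; no Python side effects inside the traced function
def update(params, x, y):
    """One gradient step on a toy least-squares loss."""
    def loss(p):
        return jnp.mean((x @ p - y) ** 2)
    value, grad = jax.value_and_grad(loss)(params)
    return params - 0.1 * grad, value


params = jnp.zeros(3)
x = jnp.ones((4, 3))
y = jnp.ones(4)
for step in range(10):
    params, value = update(params, x, y)
    if step % 5 == 0:  # logging stays in Python, outside the jitted code
        print(step, float(value))
```

Because `print` runs outside the compiled function, the update is traced and compiled a single time and then reused on every iteration, which is where the GPU efficiency comes from.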
I added the jitted_update to avoid recompilation. I'll open an issue for the renaming of maxiter -> epochs in the other examples.
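The running-time print from the last bullet needs nothing beyond the standard library; formatting the elapsed seconds through `datetime.timedelta` gives the H:MM:SS form quoted above (the training loop here is just a placeholder):

```python
import datetime
import time

start = time.time()
# ... the training loop would run here ...
elapsed = datetime.timedelta(seconds=round(time.time() - start))
print(f"Time elapsed: {elapsed}")  # e.g. "Time elapsed: 0:03:09"
```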