[CVPR 2024] MaxQ: Multi-Axis Query for N:M Sparsity Network (Paper Link)
Jingyang Xiang, Siqi Li, Junhao Chen, Zhuangzhi Chen, Tianxin Huang, Linpeng Peng, Yong Liu
PyTorch implementation of MaxQ (CVPR 2024). The main code can be found in `./models/conv_type/mullt_axis_query.py`.
In this paper, we propose an efficient and effective Multi-Axis Query methodology, dubbed MaxQ, which dynamically generates soft N:M masks by considering weight importance across multiple axes. Meanwhile, a sparsity strategy that gradually increases the percentage of N:M weight blocks is applied, allowing the network to heal progressively from pruning-induced damage.
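At its core, a soft N:M mask keeps the top-N weights (by importance) in every block of M consecutive weights while attenuating, rather than hard-zeroing, the rest. The snippet below is a minimal, illustrative sketch of this idea only; the blocking along the flattened input axis, the magnitude-based scores, and the 0.1 attenuation factor are simplifying assumptions, not the exact multi-axis mechanism in `mullt_axis_query.py`.

```python
import torch

def soft_nm_mask(weight: torch.Tensor, N: int = 2, M: int = 4) -> torch.Tensor:
    """Illustrative soft N:M mask: keep the top-N weights by magnitude in every
    block of M consecutive weights; pruned positions are attenuated instead of
    zeroed. A simplified sketch, not the exact MaxQ implementation."""
    out_ch = weight.shape[0]
    w = weight.reshape(out_ch, -1, M)            # group into N:M blocks (assumes divisibility by M)
    idx = w.abs().topk(N, dim=-1).indices        # top-N positions per block
    hard = torch.zeros_like(w).scatter_(-1, idx, 1.0)
    soft = hard + 0.1 * (1.0 - hard)             # 0.1 attenuation is an illustrative choice
    return soft.reshape_as(weight)

# e.g. a 2:4 soft mask for a conv weight of shape (64, 16, 3, 3)
mask = soft_nm_mask(torch.randn(64, 16, 3, 3), N=2, M=4)
```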
Create a data directory as the base for all datasets. For example, if your base directory is `/datadir/dataset`, then ImageNet should be located at `/datadir/dataset/imagenet`. Place the training and validation data in `/datadir/dataset/imagenet/train` and `/datadir/dataset/imagenet/val`, respectively.
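As a quick sanity check (illustrative, not part of this repo), the layout can be verified with torchvision's `ImageFolder`:

```python
from torchvision import datasets

# example paths following the layout described above
train_set = datasets.ImageFolder("/datadir/dataset/imagenet/train")
val_set = datasets.ImageFolder("/datadir/dataset/imagenet/val")
print(len(train_set), len(val_set))  # 1281167 and 50000 for full ImageNet-1k
```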
All training scripts can be generated with `./scripts/generate_scripts.py`:

```bash
python ./scripts/generate_scripts.py
```

Training follows this template:

```bash
python pruning_train.py [DATA_PATH] --set ImageNet -a [ARCH] \
    --no-bn-decay True --save_dir [SAVE_DIR] --warmup-length 0 --N 1 --M 16 --decay 0.0002 --conv-bn-type SoftMaxQConv2DBN \
    --weight-decay 0.0001 --nesterov False --workers 16 --increase-start 0 --increase-end 90
```
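For example, a 1:16 ResNet-50 run with the placeholders filled in might look like this (the data path follows the dataset section above and the save path is illustrative):

```bash
python pruning_train.py /datadir/dataset --set ImageNet -a resnet50 \
    --no-bn-decay True --save_dir ./checkpoints/resnet50_1x16 --warmup-length 0 --N 1 --M 16 --decay 0.0002 --conv-bn-type SoftMaxQConv2DBN \
    --weight-decay 0.0001 --nesterov False --workers 16 --increase-start 0 --increase-end 90
```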
All models are available from the OpenI community. Many thanks to OpenI for the storage space!
Model | N (of N:M) | M (of N:M) | Training epochs | DALI | Top-1 Acc. (%) | Top-5 Acc. (%) | Model & log |
---|---|---|---|---|---|---|---|
resnet34 | 1 | 4 | 120 | ✘ | 74.2 | 91.7 | link |
resnet34 | 2 | 4 | 120 | ✘ | 74.5 | 92.1 | link |
resnet50 | 1 | 4 | 120 | ✘ | 77.3 | 93.4 | link |
resnet50 | 1 | 16 | 200 | ✘ | 75.2 | 92.6 | link |
resnet50 | 2 | 4 | 120 | ✘ | 77.6 | 93.7 | link |
resnet50 | 2 | 8 | 120 | ✘ | 77.2 | 93.5 | link |
mobilenetv1 | 1 | 4 | 120 | ✘ | 70.4 | 89.5 | link |
mobilenetv1 | 2 | 4 | 120 | ✘ | 72.3 | 90.8 | link |
mobilenetv2 | 1 | 4 | 120 | ✘ | 67.0 | 87.5 | link |
mobilenetv2 | 2 | 4 | 120 | ✘ | 69.8 | 89.3 | link |
mobilenetv3_small | 1 | 4 | 120 | ✘ | 55.3 | 78.9 | link |
mobilenetv3_small | 2 | 4 | 120 | ✘ | 60.8 | 82.9 | link |
To evaluate a pretrained checkpoint, add `--pretrained` and `--evaluate` to the training template:

```bash
python pruning_train.py [DATA_PATH] --set ImageNet -a [ARCH] \
    --no-bn-decay True --save_dir [SAVE_DIR] --warmup-length 0 --N 1 --M 16 --decay 0.0002 --conv-bn-type SoftMaxQConv2DBN \
    --weight-decay 0.0001 --nesterov False --workers 16 --increase-start 0 --increase-end 90 \
    --pretrained [PRETRAINED_PATH] --evaluate
```
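For instance, the 2:4 ResNet-50 checkpoint from the table above could be evaluated as follows (the data and checkpoint paths are placeholders, not paths shipped with this repo):

```bash
python pruning_train.py /datadir/dataset --set ImageNet -a resnet50 \
    --no-bn-decay True --save_dir ./checkpoints/eval --warmup-length 0 --N 2 --M 4 --decay 0.0002 --conv-bn-type SoftMaxQConv2DBN \
    --weight-decay 0.0001 --nesterov False --workers 16 --increase-start 0 --increase-end 90 \
    --pretrained ./checkpoints/resnet50_2x4/model_best.pth.tar --evaluate
```

which produces a log like the following: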
```
[2024-03-02 14:05:24] Test: [0/782] Time 6.638 (6.638) Loss 1.3484 (1.3484) Prec@1 93.750 (93.750) Prec@5 98.438 (98.438)
[2024-03-02 14:05:36] Test: [100/782] Time 0.118 (0.187) Loss 1.4476 (1.6217) Prec@1 90.625 (83.261) Prec@5 96.875 (95.699)
[2024-03-02 14:05:49] Test: [200/782] Time 0.116 (0.155) Loss 1.6416 (1.6144) Prec@1 85.938 (83.225) Prec@5 96.875 (96.183)
[2024-03-02 14:06:01] Test: [300/782] Time 0.128 (0.145) Loss 1.5182 (1.6168) Prec@1 84.375 (82.942) Prec@5 95.312 (96.288)
[2024-03-02 14:06:13] Test: [400/782] Time 0.117 (0.140) Loss 1.5314 (1.7098) Prec@1 81.250 (80.556) Prec@5 98.438 (95.242)
[2024-03-02 14:06:26] Test: [500/782] Time 0.118 (0.137) Loss 1.3581 (1.7604) Prec@1 90.625 (79.416) Prec@5 98.438 (94.633)
[2024-03-02 14:06:38] Test: [600/782] Time 0.131 (0.135) Loss 1.6157 (1.8020) Prec@1 87.500 (78.557) Prec@5 96.875 (94.169)
[2024-03-02 14:06:51] Test: [700/782] Time 0.117 (0.133) Loss 1.8406 (1.8357) Prec@1 79.688 (77.739) Prec@5 95.312 (93.741)
[2024-03-02 14:07:01] * Prec@1 77.576 Prec@5 93.708 Error@1 22.424
```
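The final Prec@1 of 77.576 corresponds to the ResNet-50 2:4 entry (77.6 / 93.7) in the table above.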
```
optional arguments:
  # misc
  --save_dir            path to the save directory

  # model
  --arch                model architecture
                        default: resnet18
                        choices: ['resnet18', 'resnet34', 'resnet50', 'mobilenet_v1', 'mobilenet_v2', 'mobilenet_v3_small', 'mobilenet_v3_large']
  --conv-bn-type        conv-bn type for the network
                        default: SoftMaxQConv2DBN

  # dataset
  data                  path to the dataset
  --set                 dataset to use
                        default: ImageNet
                        choices: ["ImageNet", "ImageNetDali"]

  # pretrain, resume, or evaluate
  --evaluate            evaluate the model on the validation set
  --pretrained          path to a pretrained checkpoint

  # N:M sparsity
  --N                   N for N:M sparsity
                        default: 2
  --M                   M for N:M sparsity
                        default: 4
  --decay               decay for the SR-STE method
                        default: 0.0002
  --decay-type          decay type for the conv type
                        default: v1

  # MaxQ method
  --increase-start      epoch at which the ratio of N:M blocks starts increasing
                        default: 0
  --increase-end        epoch at which the ratio of N:M blocks stops increasing
                        default: 90
  --tau                 tau for the MaxQ method
                        default: 0.01
  --prune-schedule      pruning schedule for incremental sparsity in the MaxQ method
```
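To make the incremental-sparsity arguments concrete, the sketch below shows one plausible schedule for the fraction of weight blocks under the N:M constraint between `--increase-start` and `--increase-end`. The cubic ramp is an assumption borrowed from common gradual-pruning schedules, not necessarily what `--prune-schedule` implements.

```python
def nm_block_ratio(epoch: int, start: int = 0, end: int = 90) -> float:
    """Fraction of weight blocks constrained to N:M sparsity at a given epoch.
    Ramps from 0 at `start` to 1 at `end`; the cubic shape is illustrative."""
    if epoch <= start:
        return 0.0
    if epoch >= end:
        return 1.0
    t = (epoch - start) / (end - start)
    return 1.0 - (1.0 - t) ** 3

# e.g. with the defaults, ~70% of blocks are N:M-sparse by epoch 30
print(round(nm_block_ratio(30), 3))  # 0.704
```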
- Python 3.9.16
- PyTorch 2.0.0
- Torchvision 0.15.1
- nvidia-dali-nightly-cuda110 1.27.0.dev20230531
- nvidia-dali-tf-plugin-nightly-cuda110 1.27.0.dev20230531
Special thanks to the authors and contributors of the following projects: