[CVPR 2024] MaxQ: Multi-Axis Query for N:M Sparsity Network (Paper Link)

Jingyang Xiang, Siqi Li, Junhao Chen, Zhuangzhi Chen, Tianxin Huang, Linpeng Peng, Yong Liu

PyTorch implementation of MaxQ (CVPR 2024).

The main code can be found in ./models/conv_type/mullt_axis_query.py.

Introduction

In this paper, we propose an efficient and effective Multi-Axis Query methodology, dubbed MaxQ, which dynamically generates soft N:M masks by considering weight importance across multiple axes. Meanwhile, a sparsity strategy that gradually increases the percentage of N:M weight blocks is applied, allowing the network to heal progressively from pruning-induced damage.
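
To make the idea concrete, the sketch below builds a soft N:M mask from weight magnitudes, using per-filter and per-group importance as simplified stand-ins for the multi-axis queries. It is an illustration only; the actual implementation lives in ./models/conv_type/mullt_axis_query.py and differs in detail.

```python
import torch

def soft_nm_mask(weight: torch.Tensor, N: int = 2, M: int = 4, tau: float = 0.01) -> torch.Tensor:
    """Simplified soft N:M mask: hard top-N selection per group of M, softly re-weighted."""
    out_c = weight.shape[0]
    w = weight.reshape(out_c, -1, M)               # (out_c, groups, M); assumes divisibility by M
    score = w.abs()
    # Hard N:M part: keep the N largest-magnitude weights in every group of M.
    idx = score.topk(N, dim=-1).indices
    mask = torch.zeros_like(score).scatter_(-1, idx, 1.0)
    # Soft part: boost the kept weights with importance aggregated along the filter
    # axis and the group axis (stand-ins for the paper's multi-axis queries).
    filter_q = torch.softmax(score.sum(dim=(1, 2)) / tau, dim=0)   # (out_c,)
    group_q = torch.softmax(score.sum(dim=-1) / tau, dim=1)        # (out_c, groups)
    soft = mask * (1.0 + filter_q[:, None, None] + group_q[:, :, None])
    return soft.reshape_as(weight)

# Example: a 2:4 soft mask for a 3x3 convolution weight.
w = torch.randn(64, 64, 3, 3)
m = soft_nm_mask(w, N=2, M=4)
```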

Prepare ImageNet1K

Create a data directory as a base for all datasets. For example, if your base directory is /datadir/dataset, then ImageNet would be located at /datadir/dataset/imagenet. Place the training and validation data in /datadir/dataset/imagenet/train and /datadir/dataset/imagenet/val, respectively.
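
As a quick sanity check, the layout can be verified with torchvision's ImageFolder (illustrative only; pruning_train.py builds its own data loaders, and the paths below are examples):

```python
from torchvision import datasets

# Expect 1000 classes in each split if ImageNet-1K is laid out correctly.
train_set = datasets.ImageFolder("/datadir/dataset/imagenet/train")
val_set = datasets.ImageFolder("/datadir/dataset/imagenet/val")
print(len(train_set.classes), len(val_set.classes))
```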

Training on ImageNet1K

All training scripts can be generated with ./scripts/generate_scripts.py:

python ./scripts/generate_scripts.py
python pruning_train.py [DATA_PATH] --set ImageNet -a [ARCH] \
--no-bn-decay True --save_dir [SAVE_DIR] --warmup-length 0 --N 1 --M 16 --decay 0.0002 --conv-bn-type SoftMaxQConv2DBN \
--weight-decay 0.0001 --nesterov False --workers 16 --increase-start 0 --increase-end 90
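
The --increase-start and --increase-end flags control the epochs over which the fraction of N:M-sparse weight blocks ramps from 0% to 100%. A minimal sketch of such a schedule, assuming a cubic ramp (the actual curve is selected by --prune-schedule and may differ):

```python
def nm_block_ratio(epoch: int, increase_start: int = 0, increase_end: int = 90) -> float:
    """Fraction of weight blocks constrained to N:M sparsity at a given epoch."""
    if epoch >= increase_end:
        return 1.0
    if epoch <= increase_start:
        return 0.0
    t = (epoch - increase_start) / (increase_end - increase_start)
    return 1.0 - (1.0 - t) ** 3  # cubic ramp, a common choice for incremental pruning

print([round(nm_block_ratio(e), 2) for e in (0, 30, 60, 90)])  # [0.0, 0.7, 0.96, 1.0]
```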

Results on ImageNet1K

All models can be obtained from the OpenI community. Many thanks to OpenI for providing the storage space!

| Name | N (N:M sparsity) | M (N:M sparsity) | Training epochs | Use DALI | Top-1 Accuracy (%) | Top-5 Accuracy (%) | Model & log |
|---|---|---|---|---|---|---|---|
| resnet34 | 1 | 4 | 120 | | 74.2 | 91.7 | link |
| resnet34 | 2 | 4 | 120 | | 74.5 | 92.1 | link |
| resnet50 | 1 | 4 | 120 | | 77.3 | 93.4 | link |
| resnet50 | 1 | 16 | 200 | | 75.2 | 92.6 | link |
| resnet50 | 2 | 4 | 120 | | 77.6 | 93.7 | link |
| resnet50 | 2 | 8 | 120 | | 77.2 | 93.5 | link |
| mobilenetv1 | 1 | 4 | 120 | | 70.4 | 89.5 | link |
| mobilenetv1 | 2 | 4 | 120 | | 72.3 | 90.8 | link |
| mobilenetv2 | 1 | 4 | 120 | | 67.0 | 87.5 | link |
| mobilenetv2 | 2 | 4 | 120 | | 69.8 | 89.3 | link |
| mobilenetv3_small | 1 | 4 | 120 | | 55.3 | 78.9 | link |
| mobilenetv3_small | 2 | 4 | 120 | | 60.8 | 82.9 | link |

Testing on ImageNet1K

python pruning_train.py [DATA_PATH] --set ImageNet -a [ARCH] \
--no-bn-decay True --save_dir [SAVE_DIR] --warmup-length 0 --N 1 --M 16 --decay 0.0002 --conv-bn-type SoftMaxQConv2DBN \
--weight-decay 0.0001 --nesterov False --workers 16 --increase-start 0 --increase-end 90 \
--pretrained [PRETRAINED_PATH] --evaluate

Testing log for 2:4 ResNet50 on ImageNet1K

[2024-03-02 14:05:24] Test: [0/782]	Time 6.638 (6.638)	Loss 1.3484 (1.3484)	Prec@1 93.750 (93.750)	Prec@5 98.438 (98.438)
[2024-03-02 14:05:36] Test: [100/782]	Time 0.118 (0.187)	Loss 1.4476 (1.6217)	Prec@1 90.625 (83.261)	Prec@5 96.875 (95.699)
[2024-03-02 14:05:49] Test: [200/782]	Time 0.116 (0.155)	Loss 1.6416 (1.6144)	Prec@1 85.938 (83.225)	Prec@5 96.875 (96.183)
[2024-03-02 14:06:01] Test: [300/782]	Time 0.128 (0.145)	Loss 1.5182 (1.6168)	Prec@1 84.375 (82.942)	Prec@5 95.312 (96.288)
[2024-03-02 14:06:13] Test: [400/782]	Time 0.117 (0.140)	Loss 1.5314 (1.7098)	Prec@1 81.250 (80.556)	Prec@5 98.438 (95.242)
[2024-03-02 14:06:26] Test: [500/782]	Time 0.118 (0.137)	Loss 1.3581 (1.7604)	Prec@1 90.625 (79.416)	Prec@5 98.438 (94.633)
[2024-03-02 14:06:38] Test: [600/782]	Time 0.131 (0.135)	Loss 1.6157 (1.8020)	Prec@1 87.500 (78.557)	Prec@5 96.875 (94.169)
[2024-03-02 14:06:51] Test: [700/782]	Time 0.117 (0.133)	Loss 1.8406 (1.8357)	Prec@1 79.688 (77.739)	Prec@5 95.312 (93.741)
[2024-03-02 14:07:01]  * Prec@1 77.576 Prec@5 93.708 Error@1 22.424

Optional arguments

optional arguments:
    # misc
    --save_dir                  Path to save directory
    
    # for model                         
    --arch                      Choose model
                                default: resnet18
                                choice: ['resnet18', 'resnet34', 'resnet50', 'mobilenet_v1', 'mobilenet_v2', 'mobilenet_v3_small', 'mobilenet_v3_large']
    --conv-bn-type              convbn type for network
                                default: SoftMaxQConv2DBN
      
    # for dataset
    data                        Path to dataset
    --set                       Choose dataset
                                default: ImageNet
                                choice: ["ImageNet", "ImageNetDali"]   
                                                 
    # for pretrain, resume or evaluate
    --evaluate                  Evaluate model on validation set
    --pretrained                Path to pretrained checkpoint

    # N:M sparsity                            
    --N                         N for N:M sparsity 
                                default: 2  
    --M                         M for N:M sparsity 
                                default: 4
    --decay                     decay for the SR-STE method (see the sketch after this list)
                                default: 0.0002
    --decay-type                decay type for conv type
                                default: v1

    # MaxQ method                            
    --increase-start            Start epoch to increase ratio of N:M blocks
                                default: 0
    --increase-end              End epoch to increase ratio of N:M blocks
                                default: 90
    --tau                       Tau for MaxQ method
                                default: 0.01
    --prune-schedule            Prune scheduler for incremental sparsity in MaxQ method
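
The --decay flag corresponds to SR-STE-style regularization of pruned weights, as used in prior N:M sparsity work. A minimal sketch of that update rule, assuming a binary mask where 0 marks pruned weights (the function name and signature are illustrative, not the repo's API):

```python
import torch

def srste_grad(weight: torch.Tensor, grad: torch.Tensor,
               mask: torch.Tensor, decay: float = 2e-4) -> torch.Tensor:
    # Straight-through gradient for the dense weight, plus a regularization term
    # that pushes the currently-pruned weights (mask == 0) toward zero.
    return grad + decay * (1.0 - mask) * weight
```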

Dependencies

  • Python 3.9.16
  • Pytorch 2.0.0
  • Torchvision 0.15.1
  • nvidia-dali-nightly-cuda110 1.27.0.dev20230531
  • nvidia-dali-tf-plugin-nightly-cuda110 1.27.0.dev20230531

THANKS

Special thanks to the authors and contributors of the following projects:
