SDD-CVPR2024

Official code for the CVPR 2024 paper Scale Decoupled Distillation

Introduction

Framework

Main results

On CIFAR100

On ImageNet

On CUB200

Installation

Environments:

  • Python 3.8
  • PyTorch 1.12.0
  • torchvision 0.13.1
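
A minimal environment setup, assuming a conda workflow with the versions listed above (if pip reports a torch/torchvision conflict, pairing torch 1.12.1 with torchvision 0.13.1 is the matching combination):

    conda create -n sdd python=3.8 -y
    conda activate sdd
    pip install torch==1.12.0 torchvision==0.13.1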

Training on CIFAR-100

  • Fetch the pretrained teacher models by:

    sh fetch_pretrained_teachers.sh
    

    which will download and save the models to save/models

  • Run distillation with the commands in teacher_resnet32x4.sh, teacher_unpair.sh, teacher_vgg.sh, and teacher_wrn.sh. An example:

    python train_origin.py --cfg configs/cifar100/sdd_dkd/res32x4_shuv1.yaml --gpu 1 --M [1,2,4]
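
    Here --M [1,2,4] presumably sets the pyramid scales used by the SPP module (a global 1x1 region plus 2x2 and 4x4 regions), matching the SPP(M=M) call shown in the Core code section below.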

Training on ImageNet

Training on CUB200

  • Download the pretrained teacher models from the 'cub200' folder on baiduyun
  • Move the 'cub200' folder into the 'save' folder
  • Run the commands in train_cub_x.sh

Core code

  • We provide the implementations of SDD-KD, SDD-DKD, and SDD-NKD in KD.py, SDD_DKD.py, and SDD_nkd.py
  • We also provide the modified teacher and student models in models, with the suffix SDD

Applying SDD to a new teacher-student pair and logit distillation

  • Modify the teacher and student

    • Add the SPP module to the teacher and student models, for example line 128 in mdistiller/cifar100/resnet.py (a minimal sketch of such a module is given after this list):
      self.spp = SPP(M=M)
    • Then compute the scale-decoupled logit output from the final feature maps before pooling, for example lines 202-209 in mdistiller/cifar100/resnet.py:
      x_spp, x_strength = self.spp(x)       # x_spp: B x C x N region features, x_strength: per-region strength

      x_spp = x_spp.permute((2, 0, 1))      # B x C x N -> N x B x C
      m, b, c = x_spp.shape[0], x_spp.shape[1], x_spp.shape[2]
      x_spp = torch.reshape(x_spp, (m * b, c))                          # treat every region as a sample
      patch_score = self.fc(x_spp)                                      # shared classifier applied to each region
      patch_score = torch.reshape(patch_score, (m, b, self.class_num))
      patch_score = patch_score.permute((1, 2, 0))                      # back to B x class_num x N
  • Modify the logit distillation

    • Convert the shape
      • from B x C x N to N*B x C, where N is the number of decoupled regions
    • Calculate the distillation loss with the vanilla distillation loss
      • average or sum only over the class dimension and keep the batch dimension
    • Identify the consistent and complementary local distillation terms and adjust the weight of the complementary terms
    • See sdd_kd_loss in KD.py for an example; sketches of the SPP module and this loss are given below
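
For reference, a minimal sketch of an SPP module consistent with the self.spp(x) call above: it average-pools the final feature map at the pyramid scales given by M (e.g. [1, 2, 4]) and returns region-wise features of shape B x C x N plus a per-region strength. The pooling type and the definition of x_strength are assumptions; the actual module in this repository may differ.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SPP(nn.Module):
        # Sketch of spatial pyramid pooling over the final feature map.
        def __init__(self, M=(1, 2, 4)):
            super().__init__()
            self.M = M  # pyramid scales, e.g. 1x1 (global), 2x2, and 4x4 grids

        def forward(self, x):
            # x: B x C x H x W feature map before global pooling
            feats = []
            for m in self.M:
                pooled = F.adaptive_avg_pool2d(x, m)   # B x C x m x m
                feats.append(pooled.flatten(2))        # B x C x (m*m)
            x_spp = torch.cat(feats, dim=2)            # B x C x N, N = sum of m*m
            # assumed definition: mean activation magnitude of each region
            x_strength = x_spp.abs().mean(dim=1)       # B x N
            return x_spp, x_strength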
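
And a minimal sketch of the corresponding scale-decoupled KD loss, following the steps above. The rule used to mark a region as complementary (the teacher's local prediction disagrees with the ground-truth label) and the extra weight comp_weight are assumptions for illustration; see sdd_kd_loss in KD.py for the actual implementation.

    import torch
    import torch.nn.functional as F

    def sdd_kd_loss_sketch(logits_student, logits_teacher, target, T=4.0, comp_weight=2.0):
        # logits_student, logits_teacher: B x C x N; target: B ground-truth labels
        b, c, n = logits_student.shape

        # B x C x N -> N*B x C so every decoupled region is treated as a sample
        s = logits_student.permute(2, 0, 1).reshape(n * b, c)
        t = logits_teacher.permute(2, 0, 1).reshape(n * b, c)

        # vanilla KD loss, reduced over the class dimension only -> one loss per region
        p_t = F.softmax(t / T, dim=1)
        log_p_s = F.log_softmax(s / T, dim=1)
        loss_per_region = F.kl_div(log_p_s, p_t, reduction="none").sum(dim=1) * (T ** 2)

        # consistent vs. complementary regions (assumed rule): a region is complementary
        # when the teacher's local prediction disagrees with the ground-truth label
        target_rep = target.repeat(n)              # N*B, matches the region-major layout
        complementary = (t.argmax(dim=1) != target_rep).float()
        weight = 1.0 + (comp_weight - 1.0) * complementary

        return (weight * loss_per_region).mean()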

Acknowledgement

Thanks to CRD and DKD. We build this library on the CRD and DKD codebases.
