[NEW!] The compressed GauGAN model and its test code are released! See the GauGAN instructions below to use our models.
[NEW!] The compression tutorial is released! Check the codebase overview to better understand our code.
[NEW!] The PyTorch implementation of a general conditional GAN Compression framework is released.
We introduce GAN Compression, a general-purpose method for compressing conditional GANs. Our method reduces the computation of widely-used conditional GAN models, including pix2pix, CycleGAN, and GauGAN, by 9-21x while preserving the visual fidelity. Our method is effective for a wide range of generator architectures, learning objectives, and both paired and unpaired settings.
GAN Compression: Efficient Architectures for Interactive Conditional GANs
Muyang Li, Ji Lin, Yaoyao Ding, Zhijian Liu, Jun-Yan Zhu, and Song Han
MIT, Adobe Research, SJTU
In CVPR 2020.
GAN Compression framework: ① Given a pre-trained teacher generator G', we distill a smaller “once-for-all” student generator G that contains all possible channel numbers through weight sharing. We choose different channel numbers for the student generator G at each training step. ② We then extract many sub-generators from the “once-for-all” generator and evaluate their performance. No retraining is needed, which is the advantage of the “once-for-all” generator. ③ Finally, we choose the best sub-generator given the compression ratio target and performance target (FID or mAP), perform fine-tuning, and obtain the final compressed model.
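The three steps above can be sketched in plain Python as a toy illustration. This is not the repository's implementation: CHANNEL_CHOICES, toy_cost, and the evaluate callback are hypothetical stand-ins. A real run would train a PyTorch generator and score each sub-generator by FID or mAP. The sketch shows per-step channel sampling (①), weight sharing by slicing the first c channels of each layer, and a retraining-free search over all sub-generators under a compute budget (②, ③).

```python
import random
from itertools import product

# Hypothetical per-layer channel choices; the real search space differs per model.
CHANNEL_CHOICES = [[16, 24, 32], [16, 24, 32], [32, 48, 64]]


def sample_config(rng):
    """Step 1: sample one channel width per layer for a single training step."""
    return [rng.choice(choices) for choices in CHANNEL_CHOICES]


def slice_weights(full_weights, config):
    """Weight sharing: each layer is a list of per-output-channel weights,
    and a sub-generator reuses the first c entries of every layer."""
    return [layer[:c] for layer, c in zip(full_weights, config)]


def toy_cost(config):
    """Hypothetical compute proxy: sum of products of adjacent layer widths."""
    return sum(a * b for a, b in zip(config, config[1:]))


def search(evaluate, budget):
    """Steps 2-3: score every sub-generator that fits the budget, no retraining.

    evaluate(config) returns a score where lower is better (e.g., FID).
    Returns (best_score, best_config), or None if nothing fits the budget.
    """
    best = None
    for config in product(*CHANNEL_CHOICES):
        config = list(config)
        if toy_cost(config) > budget:
            continue  # violates the compression target
        score = evaluate(config)
        if best is None or score < best[0]:
            best = (score, config)
    return best
```

Because every sub-generator is a slice of the same shared weights, the search needs only forward evaluations; fine-tuning is then applied once, to the single chosen configuration.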
PyTorch Colab notebook: CycleGAN and pix2pix.
Prerequisites:
- Linux
- Python 3
- CPU or NVIDIA GPU + CUDA cuDNN
Installation:
- Clone this repo:
  git clone git@github.com:mit-han-lab/gan-compression.git
  cd gan-compression
- Install PyTorch 1.4 and other dependencies (e.g., torchvision).
  - For pip users, please type the command:
    pip install -r requirements.txt
  - For Conda users, we provide an installation script, scripts/conda_deps.sh. Alternatively, you can create a new Conda environment using:
    conda env create -f environment.yml
- Install torchprofile.
  pip install --upgrade git+https://github.com/mit-han-lab/torchprofile.git
CycleGAN:
- Download the CycleGAN dataset (e.g., horse2zebra).
  bash datasets/download_cyclegan_dataset.sh horse2zebra
- Get the statistics of the ground-truth images of your dataset to compute FID. We provide pre-computed real-image statistics for several datasets. For example:
  bash datasets/download_real_stat.sh horse2zebra A
  bash datasets/download_real_stat.sh horse2zebra B
- Download the pre-trained models.
  python scripts/download_model.py --model cycle_gan --task horse2zebra --stage full
  python scripts/download_model.py --model cycle_gan --task horse2zebra --stage compressed
- Test the original full model.
  bash scripts/cycle_gan/horse2zebra/test_full.sh
- Test the compressed model.
  bash scripts/cycle_gan/horse2zebra/test_compressed.sh
- Measure the latency of the two models.
  bash scripts/cycle_gan/horse2zebra/latency_full.sh
  bash scripts/cycle_gan/horse2zebra/latency_compressed.sh
Pix2pix:
- Download the pix2pix dataset (e.g., edges2shoes).
  bash datasets/download_pix2pix_dataset.sh edges2shoes-r
- Get the statistics of the ground-truth images of your dataset to compute FID. We provide pre-computed real-image statistics for several datasets. For example:
  bash datasets/download_real_stat.sh edges2shoes-r B
- Download the pre-trained models.
  python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage full
  python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage compressed
- Test the original full model.
  bash scripts/pix2pix/edges2shoes-r/test_full.sh
- Test the compressed model.
  bash scripts/pix2pix/edges2shoes-r/test_compressed.sh
- Measure the latency of the two models.
  bash scripts/pix2pix/edges2shoes-r/latency_full.sh
  bash scripts/pix2pix/edges2shoes-r/latency_compressed.sh
GauGAN:
- Prepare the Cityscapes dataset. See the Cityscapes instructions below for how to prepare it.
- Get the statistics of the ground-truth images of your dataset to compute FID. We provide pre-computed real-image statistics for several datasets. For example:
  bash datasets/download_real_stat.sh cityscapes A
- Download the pre-trained models.
  python scripts/download_model.py --model gaugan --task cityscapes --stage full
  python scripts/download_model.py --model gaugan --task cityscapes --stage compressed
- Test the original full model.
  bash scripts/gaugan/cityscapes/test_full.sh
- Test the compressed model.
  bash scripts/gaugan/cityscapes/test_compressed.sh
- Measure the latency of the two models.
  bash scripts/gaugan/cityscapes/latency_full.sh
  bash scripts/gaugan/cityscapes/latency_compressed.sh
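For intuition about what a fair latency comparison between a full and a compressed model involves, here is a minimal stdlib-only timing sketch. It is not what the latency_*.sh scripts run; measure_latency and its warmup/repeats defaults are hypothetical choices. It discards warm-up calls and reports the median of repeated runs.

```python
import statistics
import time


def measure_latency(fn, warmup=10, repeats=50):
    """Median wall-clock latency of fn() in milliseconds.

    Warm-up calls run first and are discarded so one-time costs (lazy
    initialization, caching) do not skew the measurement. For CUDA models,
    fn should end with torch.cuda.synchronize(), because GPU kernels are
    launched asynchronously and would otherwise be timed incompletely.
    """
    for _ in range(warmup):
        fn()
    timings_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings_ms)
```

With two hypothetical callables run_full and run_compressed wrapping one forward pass each, the speedup would be measure_latency(run_full) / measure_latency(run_compressed). The median is used because it is less sensitive than the mean to occasional scheduler or I/O stalls.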
We cannot provide the Cityscapes dataset due to license restrictions. Please download it from https://cityscapes-dataset.com and preprocess it with the script datasets/prepare_cityscapes_dataset.py. You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip and unzip them in the same folder. For example, you may put gtFine and leftImg8bit in database/cityscapes-origin. Then prepare the dataset with the following commands:
python datasets/get_trainIds.py database/cityscapes-origin/gtFine/
python datasets/prepare_cityscapes_dataset.py \
--gtFine_dir database/cityscapes-origin/gtFine \
--leftImg8bit_dir database/cityscapes-origin/leftImg8bit \
--output_dir database/cityscapes \
--table_path datasets/table.txt
You will get a preprocessed dataset in database/cityscapes and a mapping table (used to compute mAP) in datasets/table.txt.
To support mAP computation, you need to download a pre-trained DRN model, drn-d-105_ms_cityscapes.pth, from http://go.yf.io/drn-cityscapes-models. By default, we put the DRN model in the root directory of our repo. After downloading our compressed models, you can then test them on Cityscapes.
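Before running the preparation and test scripts, a quick sanity check can catch misplaced files. The sketch below only checks that the paths named in these instructions exist; the helper name missing_cityscapes_inputs is hypothetical, and the paths should be adjusted if you unzipped the archives elsewhere.

```python
from pathlib import Path

# Paths taken from the Cityscapes instructions; adjust if your layout differs.
REQUIRED_INPUTS = [
    "database/cityscapes-origin/gtFine",
    "database/cityscapes-origin/leftImg8bit",
    "drn-d-105_ms_cityscapes.pth",  # DRN weights, expected in the repo root
]


def missing_cityscapes_inputs(repo_root="."):
    """Return the required Cityscapes inputs that do not exist under repo_root."""
    root = Path(repo_root)
    return [rel for rel in REQUIRED_INPUTS if not (root / rel).exists()]


if __name__ == "__main__":
    missing = missing_cityscapes_inputs()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All Cityscapes inputs found.")
```

Run it from the repo root; an empty result means the raw data and the DRN model are where the commands above expect them.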
Please refer to our training tutorial on how to train models on our datasets and your own.
To compute the FID score, you need statistics of the ground-truth images of your dataset. We provide a script, get_real_stat.py, to extract them. For example, for the edges2shoes dataset, you could run the following command:
python get_real_stat.py \
--dataroot database/edges2shoes-r \
--output_path real_stat/edges2shoes-r_B.npz \
--direction AtoB
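As background on what these statistics are for: FID fits a Gaussian to image features of real and generated images and computes FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). We assume, as in pytorch-fid, that the real statistics are the mean and covariance of Inception features; this README does not spell out the exact contents of the .npz file. The stdlib sketch below computes a mean and sample covariance from feature vectors, and the Fréchet distance in the simplified case of diagonal covariances. Both helper names are hypothetical, and real FID uses full covariance matrices with a matrix square root.

```python
import math


def feature_stats(features):
    """Mean vector and sample covariance matrix (divides by n - 1).

    features: list of equal-length lists, one feature vector per image.
    """
    n = len(features)
    if n < 2:
        raise ValueError("need at least two feature vectors")
    dim = len(features[0])
    mu = [sum(row[j] for row in features) / n for j in range(dim)]
    sigma = [
        [sum((row[i] - mu[i]) * (row[j] - mu[j]) for row in features) / (n - 1)
         for j in range(dim)]
        for i in range(dim)
    ]
    return mu, sigma


def frechet_distance_diag(mu1, var1, mu2, var2):
    """Fréchet distance between two Gaussians with DIAGONAL covariances.

    With diagonal covariances, Tr((S1 S2)^(1/2)) reduces to the sum of
    sqrt(v1 * v2) over dimensions (a simplification of real FID).
    """
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    cov_term = sum(v1 + v2 - 2.0 * math.sqrt(v1 * v2) for v1, v2 in zip(var1, var2))
    return mean_term + cov_term
```

A distance of 0 means the two Gaussians are identical; larger values mean the generated feature distribution is farther from the real one, which is why lower FID is better.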
To help users better understand and use our code, we briefly overview the functionality and implementation of each package and each module.
If you use this code for your research, please cite our paper.
@inproceedings{li2020gan,
title={GAN Compression: Efficient Architectures for Interactive Conditional GANs},
author={Li, Muyang and Lin, Ji and Ding, Yaoyao and Liu, Zhijian and Zhu, Jun-Yan and Han, Song},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2020}
}
Our code is developed based on pytorch-CycleGAN-and-pix2pix and SPADE.
We also thank pytorch-fid for FID computation and drn for mAP computation.