PyTorch implementation for the paper "Rethinking Interactive Image Segmentation with Low Latency, High Quality, and Diverse Prompts", CVPR 2024.
Qin Liu, Jaemin Cho, Mohit Bansal, Marc Niethammer
UNC-Chapel Hill
The code is tested with python=3.10, torch=2.2.0, and torchvision=0.17.0.
```sh
git clone https://github.com/uncbiag/SegNext
cd SegNext
```
Now, create a new conda environment and install the required packages.
```sh
conda create -n segnext python=3.10
conda activate segnext
conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install -r requirements.txt
```
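After installation, a quick sanity check (a minimal sketch, not part of this repository) confirms that the expected PyTorch build and a CUDA device are visible:

```python
# Minimal environment sanity check (illustrative; not part of the repository).
import torch
import torchvision

print(f"torch:       {torch.__version__}")        # expect 2.2.0
print(f"torchvision: {torchvision.__version__}")  # expect 0.17.0
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
```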
First, download three model weights: vitb_sax1 (408 MB), vitb_sax2 (435 MB), and vitb_sax2_ft (435 MB). These weights will be automatically saved to the weights folder.
```sh
python download.py
```
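If the download succeeds, the checkpoints should appear under the weights folder with roughly the sizes listed above. A small sketch to verify this (the exact filenames depend on download.py):

```python
# List downloaded checkpoints and their sizes (illustrative sketch only).
from pathlib import Path

weights_dir = Path("./weights")
for ckpt in sorted(weights_dir.glob("*.pth")):
    size_mb = ckpt.stat().st_size / 1024 ** 2
    print(f"{ckpt.name}: {size_mb:.0f} MB")  # expect roughly 408 MB / 435 MB per checkpoint
```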
Run the interactive GUI with the downloaded weights. The assets folder contains images for the demo.
```sh
./run_demo.sh
```
We train and test our method on three datasets: DAVIS, COCO+LVIS, and HQSeg-44K.
Dataset | Description | Download Link |
---|---|---|
DAVIS | 345 images with one object each (test) | DAVIS.zip (43 MB) |
HQSeg-44K | 44320 images (train); 1537 images (val) | official site |
COCO+LVIS* | 99k images with 1.5M instances (train) | original LVIS images + combined annotations |
Don't forget to change the paths to the datasets in config.yml after downloading and unpacking.
(*) To prepare COCO+LVIS, download the original LVIS v1.0 images, then download and unpack the pre-processed annotations, which are obtained by combining the COCO and LVIS datasets, into the folder with LVIS v1.0. (The combined annotations are prepared by RITM.)
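Because the dataset locations are read from config.yml, a small check like the one below (an illustrative sketch; it assumes path entries are plain top-level string values, as in RITM-style configs) helps catch wrong or missing paths before training or evaluation:

```python
# Verify that path-like entries in config.yml point to existing locations (illustrative sketch).
import os
import yaml  # pip install pyyaml

with open("config.yml") as f:
    cfg = yaml.safe_load(f)

for key, value in cfg.items():
    # Treat any top-level string value that looks like a filesystem path as a dataset/weights location.
    if isinstance(value, str) and ("/" in value or value.startswith(".")):
        status = "ok" if os.path.exists(value) else "MISSING"
        print(f"{key:30s} {value:60s} {status}")
```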
We provide a script (run_eval.sh) to evaluate our models. The following command runs the NoC evaluation on all test datasets.
```sh
python ./segnext/scripts/evaluate_model.py --gpus=0 --checkpoint=./weights/vitb_sa2_cocolvis_hq44k_epoch_0.pth --datasets=DAVIS,HQSeg44K
```
Train Dataset | Model | HQSeg-44K 5-mIoU | HQSeg-44K NoC90 | HQSeg-44K NoC95 | HQSeg-44K NoF95 | DAVIS 5-mIoU | DAVIS NoC90 | DAVIS NoC95 | DAVIS NoF95 |
---|---|---|---|---|---|---|---|---|---|
C+L | vitb-sax1 (408 MB) | 85.41 | 7.47 | 11.94 | 731 | 90.13 | 5.46 | 13.31 | 177 |
C+L | vitb-sax2 (435 MB) | 85.71 | 7.18 | 11.52 | 700 | 89.85 | 5.34 | 12.80 | 163 |
C+L+HQ | vitb-sax2 (435 MB) | 91.75 | 5.32 | 9.42 | 583 | 91.87 | 4.43 | 10.73 | 123 |
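For reference on the columns: NoC90/NoC95 report the average number of clicks needed to reach 90%/95% IoU, NoF95 counts images that never reach 95% IoU within the click budget, and 5-mIoU is the mean IoU after 5 clicks. Below is a hedged sketch of how such numbers can be computed from per-click IoU curves; the evaluation script has its own implementation, and the 20-click budget here is an assumption:

```python
# Sketch of click-based metrics from per-click IoU curves (illustrative; not the repo's implementation).
import numpy as np

def click_metrics(ious_per_image, iou_thr=0.90, max_clicks=20):
    """ious_per_image[i][k] = IoU on image i after k+1 clicks."""
    nocs, failures = [], 0
    for ious in ious_per_image:
        reached = np.nonzero(np.asarray(ious) >= iou_thr)[0]
        if len(reached) > 0:
            nocs.append(reached[0] + 1)   # clicks needed to first reach the threshold
        else:
            nocs.append(max_clicks)       # budget exhausted
            failures += 1                 # counted in NoF at this threshold
    return float(np.mean(nocs)), failures

# Toy usage with two synthetic IoU curves.
curves = [np.linspace(0.5, 0.97, 20), np.linspace(0.4, 0.92, 20)]
noc90, nof90 = click_metrics(curves, iou_thr=0.90)
miou_at_5 = float(np.mean([c[4] for c in curves]))  # "5-mIoU": mean IoU after 5 clicks
print(noc90, nof90, miou_at_5)
```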
For SAT latency evaluation, please refer to eval_sat_latency.ipynb.
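The notebook measures model latency on GPU; the usual pattern for such measurements (shown below as a generic sketch, not the notebook's code) is to warm up, synchronize around the timed region, and average over repeated forward passes:

```python
# Generic GPU latency measurement pattern (illustrative; see eval_sat_latency.ipynb for the actual evaluation).
import time
import torch

@torch.no_grad()
def measure_latency_ms(model, example_input, warmup=10, iters=50):
    model.eval()
    for _ in range(warmup):      # warm-up iterations (kernel compilation, allocator)
        model(example_input)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        model(example_input)
    torch.cuda.synchronize()     # wait for all queued kernels before stopping the clock
    return (time.perf_counter() - start) / iters * 1000.0
```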
We provide a script (run_train.sh) for training our models on COCO+LVIS and fine-tuning on HQSeg-44K. You can start training with the following commands. By default, we use 4 A6000 GPUs for training.
```sh
# train vitb-sax1 model on coco+lvis
MODEL_CONFIG=./segnext/models/default/plainvit_base1024_cocolvis_sax1.py
torchrun --nproc-per-node=4 --master-port 29504 ./segnext/train.py ${MODEL_CONFIG} --batch-size=16 --gpus=0,1,2,3

# train vitb-sax2 model on coco+lvis
MODEL_CONFIG=./segnext/models/default/plainvit_base1024_cocolvis_sax2.py
torchrun --nproc-per-node=4 --master-port 29505 ./segnext/train.py ${MODEL_CONFIG} --batch-size=16 --gpus=0,1,2,3

# finetune vitb-sax2 model on hqseg-44k
MODEL_CONFIG=./segnext/models/default/plainvit_base1024_hqseg44k_sax2.py
torchrun --nproc-per-node=4 --master-port 29506 ./segnext/train.py ${MODEL_CONFIG} --batch-size=12 --gpus=0,1,2,3 --weights ./weights/vitb_sa2_cocolvis_epoch_90.pth
```
```bibtex
@article{liu2024rethinking,
  title={Rethinking Interactive Image Segmentation with Low Latency, High Quality, and Diverse Prompts},
  author={Liu, Qin and Cho, Jaemin and Bansal, Mohit and Niethammer, Marc},
  journal={arXiv preprint arXiv:2404.00741},
  year={2024}
}
```