Commit bb9d2fd

add reparameterized version
1 parent 0f28744 commit bb9d2fd

13 files changed: +113 -59 lines changed

README.md

+13-1
@@ -170,7 +170,19 @@ chmod +x tools/dist_test.sh

 ## Fine-tuning YOLO-World

-We provide the details about fine-tuning YOLO-World in [docs/fine-tuning](./docs/finetuning.md).
+<div align="center">
+<img src="./assets/finetune_yoloworld.png" width=800px>
+</div>
+
+
+YOLO-World supports **zero-shot inference**, and three types of **fine-tuning recipes**: **(1) normal fine-tuning**, **(2) prompt tuning**, and **(3) reparameterized fine-tuning**.
+
+
+* Normal Fine-tuning: we provide the details about fine-tuning YOLO-World in [docs/fine-tuning](./docs/finetuning.md).
+
+* Prompt Tuning: we provide more details about prompt tuning in [docs/prompt_yolo_world](./docs/prompt_yolo_world.md).
+
+* Reparameterized Fine-tuning: the reparameterized YOLO-World is more suitable for specific domains far from generic scenes. You can find more details in [`docs/reparameterize`](./docs/reparameterize.md).

 ## Deployment

assets/finetune_yoloworld.png

466 KB

assets/reparameterize.png

62.8 KB

configs/finetune_coco/README.md

+4-2
@@ -18,11 +18,13 @@ BTW, the COCO fine-tuning results are updated with higher performance (with `mas
 | [YOLO-World-v2-S](./yolo_world_v2_s_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 37.5 | 46.1 | 62.0 | 49.9 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_s_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-492dc329.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_s_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240327_110411.log) |
 | [YOLO-World-v2-M](./yolo_world_v2_m_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 42.8 | 51.0 | 67.5 | 55.2 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_m_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-69c27ac7.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_m_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240327_110411.log) |
 | [YOLO-World-v2-L](./yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 45.1 | 53.9 | 70.9 | 58.8 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-81c701ee.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240326_160313.log) |
-| [YOLO-World-v2-L](./yolo_world_v2_l_efficient_neck_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✔️ | 45.1 | | | | [HF Checkpoints]() | [log]() |
 | [YOLO-World-v2-X](./yolo_world_v2_x_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 46.8 | 54.7 | 71.6 | 59.6 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_x_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_ep80-76bc0cbd.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_x_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco_20240322_181232.log) |
 | [YOLO-World-v2-L](./yolo_world_v2_l_vlpan_bn_sgd_1e-3_40e_8gpus_finetune_coco.py) 🔥 | SGD, 1e-3, 40e | ✖️ | ✖️ | 45.1 | 52.8 | 69.5 | 57.8 | [HF Checkpoints](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_sgd_1e-3_40e_8gpus_finetune_coco_ep80-e1288152.pth) | [log](https://huggingface.co/wondervictor/YOLO-World/blob/main/yolo_world_v2_l_vlpan_bn_sgd_1e-3_40e_8gpus_finetuning_coco_20240327_014902.log) |


+### Reparameterized Training


-### Reparameterized Training
+| model | Schedule | `mask-refine` | efficient neck | AP<sup>ZS</sup>| AP | AP<sub>50</sub> | AP<sub>75</sub> | weights | log |
+| :---- | :-------: | :----------: |:-------------: | :------------: | :-: | :--------------:| :-------------: |:------: | :-: |
+| [YOLO-World-v2-S](./yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) | AdamW, 2e-4, 80e | ✔️ | ✖️ | 37.5 | 46.3 | 62.8 | 50.4 | [HF Checkpoints]() | [log]() |

configs/finetune_coco/yolo_world_v2_l_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py → configs/finetune_coco/yolo_world_v2_s_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py

+9-20
@@ -1,5 +1,5 @@
 _base_ = ('../../third_party/mmyolo/configs/yolov8/'
-          'yolov8_l_mask-refine_syncbn_fast_8xb16-500e_coco.py')
+          'yolov8_s_mask-refine_syncbn_fast_8xb16-500e_coco.py')
 custom_imports = dict(imports=['yolo_world'], allow_failed_imports=False)

 # hyper-parameters
@@ -11,11 +11,13 @@
 text_channels = 512
 neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
 neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
-base_lr = 2e-3
+base_lr = 2e-4
 weight_decay = 0.05
 train_batch_size_per_gpu = 16
-load_from = '/group/40034/adriancheng/notebooks/rep_models/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain_1280ft-14996a36_repconv.pth'
+load_from = '../FastDet/output_models/pretrain_yolow-v8_s_clipv2_frozen_te_noprompt_t2i_bn_2e-3adamw_scale_lr_wd_32xb16-100e_obj365v1_goldg_cc3mram250k_train_lviseval-e3592307_rep_conv.pth'
 persistent_workers = False
+mixup_prob = 0.15
+copypaste_prob = 0.3

 # model settings
 model = dict(type='SimpleYOLOWorldDetector',
@@ -28,14 +30,12 @@
                            type='MultiModalYOLOBackbone',
                            text_model=None,
                            image_model={{_base_.model.backbone}},
-                           frozen_stages=4,
                            with_text_model=False),
              neck=dict(type='YOLOWorldPAFPN',
-                       guide_channels=num_classes,
+                       guide_channels=text_channels,
                        embed_channels=neck_embed_channels,
                        num_heads=neck_num_heads,
-                       block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv',
-                                      guide_channels=num_classes)),
+                       block_cfg=dict(type='EfficientCSPLayerWithTwoConv')),
              bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule',
                                              embed_dims=text_channels,
                                              num_guide=num_classes,
@@ -53,7 +53,7 @@
          img_scale=_base_.img_scale,
          pad_val=114.0,
          pre_transform=_base_.pre_transform),
-    dict(type='YOLOv5CopyPaste', prob=_base_.copypaste_prob),
+    dict(type='YOLOv5CopyPaste', prob=copypaste_prob),
     dict(
         type='YOLOv5RandomAffine',
         max_rotate_degree=0.0,
@@ -69,7 +69,7 @@
 train_pipeline = [
     *_base_.pre_transform, *mosaic_affine_transform,
     dict(type='YOLOv5MixUp',
-         prob=_base_.mixup_prob,
+         prob=mixup_prob,
          pre_transform=[*_base_.pre_transform, *mosaic_affine_transform]),
     *_base_.last_transform[:-1], *final_transform
 ]
@@ -135,16 +135,6 @@
     lr=base_lr,
     weight_decay=weight_decay,
     batch_size_per_gpu=train_batch_size_per_gpu),
-                     paramwise_cfg=dict(bias_decay_mult=0.0,
-                                        norm_decay_mult=0.0,
-                                        custom_keys={
-                                            'backbone.text_model':
-                                            dict(lr_mult=0.01),
-                                            'logit_scale':
-                                            dict(weight_decay=0.0),
-                                            'embeddings':
-                                            dict(weight_decay=0.0)
-                                        }),
                      constructor='YOLOWv5OptimizerConstructor')

 # evaluation settings
@@ -153,4 +143,3 @@
                      proposal_nums=(100, 1, 10),
                      ann_file='data/coco/annotations/instances_val2017.json',
                      metric='bbox')
-find_unused_parameters = True

configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py → configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_2e-3_80e_8gpus_mask-refine_finetune_coco.py

+1-1
@@ -11,7 +11,7 @@
 text_channels = 512
 neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
 neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
-base_lr = 1e-3
+base_lr = 2e-3
 weight_decay = 0.0005
 train_batch_size_per_gpu = 16
 load_from = '../FastDet/output_models/yolo_world_s_clip_t2i_bn_2e-3adamw_32xb16-100e_obj365v1_goldg_train-55b943ea_rep_conv.pth'

configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_sgd_1e-3_80e_8gpus_mask-refine_finetune_coco.py → configs/finetune_coco/yolo_world_v2_s_rep_efficient_vlpan_sgd_5e-4_80e_8gpus_mask-refine_finetune_coco.py

+3-5
@@ -11,7 +11,7 @@
 text_channels = 512
 neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
 neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
-base_lr = 1e-3
+base_lr = 5e-4
 weight_decay = 0.0005
 train_batch_size_per_gpu = 16
 load_from = '../FastDet/output_models/yolo_world_s_clip_t2i_bn_2e-3adamw_32xb16-100e_obj365v1_goldg_train-55b943ea_rep_conv.pth'
@@ -32,11 +32,10 @@
                            image_model={{_base_.model.backbone}},
                            with_text_model=False),
              neck=dict(type='YOLOWorldPAFPN',
-                       guide_channels=num_classes,
+                       guide_channels=text_channels,
                        embed_channels=neck_embed_channels,
                        num_heads=neck_num_heads,
-                       block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv',
-                                      guide_channels=num_classes)),
+                       block_cfg=dict(type='EfficientCSPLayerWithTwoConv')),
              bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule',
                                              embed_dims=text_channels,
                                              num_guide=num_classes,
@@ -140,7 +139,6 @@
     batch_size_per_gpu=train_batch_size_per_gpu),
                      constructor='YOLOWv5OptimizerConstructor')

-
 # evaluation settings
 val_evaluator = dict(_delete_=True,
                      type='mmdet.CocoMetric',

configs/finetune_coco/yolo_world_v2_x_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py → configs/finetune_coco/yolo_world_v2_x_rep_efficient_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py

+11-20
@@ -1,6 +1,9 @@
-_base_ = ('../../third_party/mmyolo/configs/yolov8/'
-          'yolov8_x_mask-refine_syncbn_fast_8xb16-500e_coco.py')
-custom_imports = dict(imports=['yolo_world'], allow_failed_imports=False)
+_base_ = (
+    '../../third_party/mmyolo/configs/yolov8/'
+    'yolov8_x_mask-refine_syncbn_fast_8xb16-500e_coco.py')
+custom_imports = dict(
+    imports=['yolo_world'],
+    allow_failed_imports=False)

 # hyper-parameters
 num_classes = 80
@@ -11,12 +14,13 @@
 text_channels = 512
 neck_embed_channels = [128, 256, _base_.last_stage_out_channels // 2]
 neck_num_heads = [4, 8, _base_.last_stage_out_channels // 2 // 32]
-base_lr = 2e-3
+base_lr = 2e-4
 weight_decay = 0.05
 train_batch_size_per_gpu = 16
-load_from = '/group/40034/adriancheng/notebooks/rep_models/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain_1280ft-14996a36_repconv.pth'
+load_from = '../YOLOWorld_Master/yolo_models/'
 persistent_workers = False

+
 # model settings
 model = dict(type='SimpleYOLOWorldDetector',
              mm_neck=True,
@@ -28,14 +32,12 @@
                            type='MultiModalYOLOBackbone',
                            text_model=None,
                            image_model={{_base_.model.backbone}},
-                           frozen_stages=4,
                            with_text_model=False),
              neck=dict(type='YOLOWorldPAFPN',
-                       guide_channels=num_classes,
+                       guide_channels=text_channels,
                        embed_channels=neck_embed_channels,
                        num_heads=neck_num_heads,
-                       block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv',
-                                      guide_channels=num_classes)),
+                       block_cfg=dict(type='EfficientCSPLayerWithTwoConv')),
              bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule',
                                              embed_dims=text_channels,
                                              num_guide=num_classes,
@@ -135,16 +137,6 @@
     lr=base_lr,
     weight_decay=weight_decay,
     batch_size_per_gpu=train_batch_size_per_gpu),
-                     paramwise_cfg=dict(bias_decay_mult=0.0,
-                                        norm_decay_mult=0.0,
-                                        custom_keys={
-                                            'backbone.text_model':
-                                            dict(lr_mult=0.01),
-                                            'logit_scale':
-                                            dict(weight_decay=0.0),
-                                            'embeddings':
-                                            dict(weight_decay=0.0)
-                                        }),
                      constructor='YOLOWv5OptimizerConstructor')

 # evaluation settings
@@ -153,4 +145,3 @@
                      proposal_nums=(100, 1, 10),
                      ann_file='data/coco/annotations/instances_val2017.json',
                      metric='bbox')
-find_unused_parameters = True

deploy/easydeploy/model/model.py

+8-6
@@ -28,12 +28,14 @@ def __init__(self,
                  baseModel: nn.Module,
                  backend: MMYOLOBackend,
                  postprocess_cfg: Optional[ConfigDict] = None,
-                 with_nms=True):
+                 with_nms=True,
+                 without_bbox_decoder=False):
         super().__init__()
         self.baseModel = baseModel
         self.baseHead = baseModel.bbox_head
         self.backend = backend
         self.with_nms = with_nms
+        self.without_bbox_decoder = without_bbox_decoder
         if postprocess_cfg is None:
             self.with_postprocess = False
         else:
@@ -103,7 +105,8 @@ def pred_by_feat(self,
             bbox_decoder = yolox_bbox_decoder
         else:
             bbox_decoder = self.bbox_decoder
-
+        print(bbox_decoder)
+
         num_imgs = cls_scores[0].shape[0]
         featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

@@ -112,7 +115,6 @@ def pred_by_feat(self,
             device=device)

         flatten_priors = torch.cat(mlvl_priors)
-
         mlvl_strides = [
             flatten_priors.new_full(
                 (featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
@@ -121,8 +123,6 @@ def pred_by_feat(self,
         ]
         flatten_stride = torch.cat(mlvl_strides)

-        # flatten cls_scores, bbox_preds and objectness
-        # using score.shape
         text_len = cls_scores[0].shape[1]
         flatten_cls_scores = [
             cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, text_len)
@@ -145,7 +145,9 @@ def pred_by_feat(self,
             cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))

         scores = cls_scores
-
+        bboxes = flatten_bbox_preds
+        if self.without_bbox_decoder:
+            return scores, bboxes
         bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
                               flatten_stride)

deploy/export_onnx.py

+5-1
@@ -37,6 +37,9 @@ def parse_args():
     parser.add_argument('--without-nms',
                         action='store_true',
                         help='Export model without NMS')
+    parser.add_argument('--without-bbox-decoder',
+                        action='store_true',
+                        help='Export model without Bbox Decoder (for INT8 Quantization)')
     parser.add_argument('--work-dir',
                         default='./work_dirs',
                         help='Path to save export model')
@@ -129,7 +132,8 @@ def main():
     deploy_model = DeployModel(baseModel=baseModel,
                                backend=backend,
                                postprocess_cfg=postprocess_cfg,
-                               with_nms=not args.without_nms)
+                               with_nms=not args.without_nms,
+                               without_bbox_decoder=args.without_bbox_decoder)
     deploy_model.eval()

     fake_input = torch.randn(args.batch_size, 3,
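
Reviewer note: the new `--without-bbox-decoder` switch is a plain `store_true` flag that is passed straight through to `DeployModel`, while `--without-nms` is inverted into `with_nms`. A minimal, self-contained sketch of that mapping follows; `parse_export_flags` and `build_deploy_kwargs` are hypothetical helper names used only for illustration and are not part of `deploy/export_onnx.py`.

```python
import argparse


def parse_export_flags(argv):
    """Parse only the two export switches discussed in this diff."""
    parser = argparse.ArgumentParser(description='Export flag sketch')
    parser.add_argument('--without-nms',
                        action='store_true',
                        help='Export model without NMS')
    parser.add_argument('--without-bbox-decoder',
                        action='store_true',
                        help='Export model without Bbox Decoder')
    return parser.parse_args(argv)


def build_deploy_kwargs(args):
    """Hypothetical helper: map CLI flags to DeployModel keyword arguments."""
    return dict(with_nms=not args.without_nms,
                without_bbox_decoder=args.without_bbox_decoder)


# No flags: NMS stays on and boxes are decoded inside the exported graph.
default_kwargs = build_deploy_kwargs(parse_export_flags([]))
# INT8-oriented export: raw box predictions are returned undecoded.
int8_kwargs = build_deploy_kwargs(
    parse_export_flags(['--without-bbox-decoder']))
```

Note that, as the `pred_by_feat` change in this commit shows, setting `without_bbox_decoder` returns `(scores, bboxes)` before the decoder runs, so the NMS step that follows is skipped on that path as well.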

docs/reparameterize.md

+52-2
@@ -2,6 +2,10 @@

 The reparameterization incorporates text embeddings as parameters into the model. For example, in the final classification layer, text embeddings are reparameterized into a simple 1x1 convolutional layer.

+<div align="center">
+<img width="600" src="../assets/reparameterize.png">
+</div>
+
 ### Key Advantages from Reparameterization

 > Reparameterized YOLO-World still has zero-shot ability!
@@ -15,13 +19,59 @@ For example, fine-tuning the **reparameterized YOLO-World** obtains *46.3 AP* on

 #### 1. Prepare custom text embeddings

-You need to generate the text embeddings
-
+You need to generate the text embeddings with [`tools/generate_text_prompts.py`](../tools/generate_text_prompts.py) and save them as a `numpy.array` with shape `NxD`.

 #### 2. Reparameterizing

+Reparameterizing will generate a new checkpoint with text embeddings!
+
+Check those files first:
+
+* model checkpoint
+* text embeddings

+We mainly reparameterize two groups of modules:
+
+* head (`YOLOWorldHeadModule`)
+* neck (`MaxSigmoidCSPLayerWithTwoConv`)
+
+```bash
+python tools/reparameterize_yoloworld.py \
+    --model path/to/checkpoint \
+    --out-dir path/to/save/re-parameterized/ \
+    --text-embed path/to/text/embeddings \
+    --conv-neck
+```


 #### 3. Prepare the model config

+Please see the sample config: [`finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py`](../configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) for reparameterized training.
+
+
+* `RepConvMaxSigmoidCSPLayerWithTwoConv`:
+
+```python
+    neck=dict(type='YOLOWorldPAFPN',
+              guide_channels=num_classes,
+              embed_channels=neck_embed_channels,
+              num_heads=neck_num_heads,
+              block_cfg=dict(type='RepConvMaxSigmoidCSPLayerWithTwoConv',
+                             guide_channels=num_classes)),
+```
+
+* `RepYOLOWorldHeadModule`:
+
+```python
+    bbox_head=dict(head_module=dict(type='RepYOLOWorldHeadModule',
+                                    embed_dims=text_channels,
+                                    num_guide=num_classes,
+                                    num_classes=num_classes)),
+
+```
+
+#### 4. Reparameterized Training
+
+**Reparameterized YOLO-World** is easier to fine-tune and can be treated as an enhanced and pre-trained YOLOv8!
+
+You can check [`finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py`](../configs/finetune_coco/yolo_world_v2_s_rep_vlpan_bn_2e-4_80e_8gpus_mask-refine_finetune_coco.py) for more details.
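
Reviewer note: the doc states that, in the final classification layer, text embeddings are reparameterized into a 1x1 convolution. The dependency-free sketch below illustrates only that core identity: scoring each pixel against `N` text embeddings by dot product gives the same result as a bias-free 1x1 convolution whose kernel is the embedding matrix reshaped to `N x D x 1 x 1`. All function names here are hypothetical, and the sketch deliberately omits whatever normalization, scaling, or bias terms the real head applies.

```python
# Hypothetical, pure-Python illustration; not the repository's implementation.

def text_similarity(feature_map, text_embeds):
    """Dot-product score of every pixel against every text embedding.

    feature_map: D x H x W nested lists; text_embeds: N x D nested lists.
    Returns N x H x W scores.
    """
    dim = len(feature_map)
    height, width = len(feature_map[0]), len(feature_map[0][0])
    return [[[sum(embed[d] * feature_map[d][y][x] for d in range(dim))
              for x in range(width)]
             for y in range(height)]
            for embed in text_embeds]


def embeds_to_conv1x1_weight(text_embeds):
    """Reshape N x D embeddings into an N x D x 1 x 1 convolution kernel."""
    return [[[[value]] for value in embed] for embed in text_embeds]


def conv1x1(feature_map, weight):
    """Bias-free 1x1 convolution; weight has shape N x D x 1 x 1."""
    dim = len(feature_map)
    height, width = len(feature_map[0]), len(feature_map[0][0])
    return [[[sum(kernel[d][0][0] * feature_map[d][y][x] for d in range(dim))
              for x in range(width)]
             for y in range(height)]
            for kernel in weight]


# Toy sizes: D=2 channels, a 1x2 feature map, N=3 class embeddings.
features = [[[1.0, 2.0]], [[3.0, 4.0]]]
embeds = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

dynamic_scores = text_similarity(features, embeds)
static_scores = conv1x1(features, embeds_to_conv1x1_weight(embeds))
```

Because the two score maps agree, the embeddings can be stored in the checkpoint as ordinary convolution weights and the text encoder is no longer needed at inference time, which matches `with_text_model=False` in the configs of this commit.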

tools/reparameterize_yoloworld.py

+5
@@ -49,6 +49,8 @@ def reparameterize_head(state_dict, embeds):

 def convert_neck_split_conv(input_state_dict, block_name, text_embeds,
                             num_heads):
+    if block_name + '.guide_fc.weight' not in input_state_dict:
+        return input_state_dict
     guide_fc_weight = input_state_dict[block_name + '.guide_fc.weight']
     guide_fc_bias = input_state_dict[block_name + '.guide_fc.bias']
     guide = text_embeds @ guide_fc_weight.transpose(0,
@@ -77,12 +79,15 @@ def convert_neck_weight(input_state_dict, block_name, embeds, num_heads):


 def reparameterize_neck(state_dict, embeds, type='conv'):
+
     neck_blocks = [
         'neck.top_down_layers.0.attn_block',
         'neck.top_down_layers.1.attn_block',
         'neck.bottom_up_layers.0.attn_block',
         'neck.bottom_up_layers.1.attn_block'
     ]
+    if "neck.top_down_layers.0.attn_block.bias" not in state_dict:
+        return state_dict
     for block in neck_blocks:
         num_heads = state_dict[block + '.bias'].shape[0]
         if type == 'conv':
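
Reviewer note: both guards added here make the neck conversion a no-op when the checkpoint has no text-guided attention blocks, which appears to be the case for the `EfficientCSPLayerWithTwoConv` necks introduced in the configs above (an inference from this diff, not something the commit states). A minimal sketch of that early-return pattern, using plain dictionaries instead of tensors; `reparameterize_neck_sketch` and `mark_folded` are hypothetical names:

```python
NECK_BLOCKS = [
    'neck.top_down_layers.0.attn_block',
    'neck.top_down_layers.1.attn_block',
    'neck.bottom_up_layers.0.attn_block',
    'neck.bottom_up_layers.1.attn_block',
]


def reparameterize_neck_sketch(state_dict, fold_block):
    """Return the checkpoint unchanged when no attention block is present."""
    if NECK_BLOCKS[0] + '.bias' not in state_dict:
        return state_dict
    for block in NECK_BLOCKS:
        state_dict = fold_block(state_dict, block)
    return state_dict


def mark_folded(state_dict, block):
    """Hypothetical stand-in for the real per-block weight conversion."""
    updated = dict(state_dict)
    updated[block + '.folded'] = True
    return updated


# A checkpoint without attention blocks versus one that has all four.
efficient_neck = {'neck.top_down_layers.0.main_conv.weight': 0}
guided_neck = {block + '.bias': 0 for block in NECK_BLOCKS}

untouched = reparameterize_neck_sketch(efficient_neck, mark_folded)
converted = reparameterize_neck_sketch(guided_neck, mark_folded)
```

The guard checks only the first block's key, so a checkpoint with a partial set of attention blocks would still raise a `KeyError` inside the loop; the same is true of the committed code.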
