Commit 1e6bdc0

Merge pull request huggingface#63 from jinwonkim93/feat/mobileSAM
Add MobileSAM to SamDetector
2 parents 527bf56 + 486deff, commit 1e6bdc0

8 files changed, +777 -6 lines


README.md (+1)

@@ -66,6 +66,7 @@ lineart = LineartDetector.from_pretrained("lllyasviel/Annotators")
 lineart_anime = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
 zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
 sam = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
+mobile_sam = SamDetector.from_pretrained("dhkim2810/MobileSAM", model_type="vit_t", filename="mobile_sam.pt")
 leres = LeresDetector.from_pretrained("lllyasviel/Annotators")


 # instantiate
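
For context, a minimal usage sketch of the new MobileSAM path. It assumes only what the README diff shows (the dhkim2810/MobileSAM repo id, model_type="vit_t", filename="mobile_sam.pt"); the input path and the return type are illustrative assumptions.

from PIL import Image
from controlnet_aux import SamDetector

# Load the lightweight MobileSAM (TinyViT) weights instead of the default
# ViT-H SAM; repo id, model_type, and filename come from the README diff above.
mobile_sam = SamDetector.from_pretrained(
    "dhkim2810/MobileSAM", model_type="vit_t", filename="mobile_sam.pt"
)

image = Image.open("input.png")  # hypothetical input image
segmented = mobile_sam(image)    # detectors in this package return a PIL image
segmented.save("segmentation.png")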

src/controlnet_aux/__init__.py (+1 -1)

@@ -14,4 +14,4 @@
 from .canny import CannyDetector
 from .mediapipe_face import MediapipeFaceDetector
 from .segment_anything import SamDetector
-from .shuffle import ContentShuffleDetector
\ No newline at end of file
+from .shuffle import ContentShuffleDetector

src/controlnet_aux/segment_anything/__init__.py (+1 -1)

@@ -26,7 +26,7 @@ def __init__(self, mask_generator: SamAutomaticMaskGenerator):
     @classmethod
     def from_pretrained(cls, pretrained_model_or_path, model_type="vit_h", filename="sam_vit_h_4b8939.pth", subfolder=None, cache_dir=None):
         """
-        Possible model_type : vit_h, vit_l, vit_b
+        Possible model_type : vit_h, vit_l, vit_b, vit_t
         download weights from https://github.com/facebookresearch/segment-anything
         """
         if os.path.isdir(pretrained_model_or_path):
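
The body of from_pretrained is mostly elided from this diff. A rough sketch of how these arguments are plausibly threaded through (the hf_hub_download fallback and the exact import paths are assumptions, not shown in the diff; sam_model_registry comes from build_sam.py below):

import os
from huggingface_hub import hf_hub_download
from controlnet_aux.segment_anything import SamAutomaticMaskGenerator, sam_model_registry

def load_detector_model(pretrained_model_or_path, model_type="vit_t",
                        filename="mobile_sam.pt", subfolder=None, cache_dir=None):
    # Resolve the checkpoint from either a local directory or a Hub repo id.
    if os.path.isdir(pretrained_model_or_path):
        checkpoint = os.path.join(pretrained_model_or_path, filename)
    else:
        checkpoint = hf_hub_download(pretrained_model_or_path, filename,
                                     subfolder=subfolder, cache_dir=cache_dir)
    # model_type selects the builder; after this PR, "vit_t" maps to build_sam_vit_t.
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    return SamAutomaticMaskGenerator(sam)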

src/controlnet_aux/segment_anything/build_sam.py (+53 -1)

@@ -8,7 +8,7 @@

 from functools import partial

-from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT


 def build_sam_vit_h(checkpoint=None):
@@ -44,11 +44,61 @@ def build_sam_vit_b(checkpoint=None):
     )


+def build_sam_vit_t(checkpoint=None):
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    mobile_sam = Sam(
+        image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000,
+            embed_dims=[64, 128, 160, 320],
+            depths=[2, 2, 6, 2],
+            num_heads=[2, 4, 5, 10],
+            window_sizes=[7, 7, 14, 7],
+            mlp_ratio=4.,
+            drop_rate=0.,
+            drop_path_rate=0.0,
+            use_checkpoint=False,
+            mbconv_expand_ratio=4.0,
+            local_conv_size=3,
+            layer_lr_decay=0.8
+        ),
+        prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+        ),
+        mask_decoder=MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+        ),
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+
+    mobile_sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)
+        mobile_sam.load_state_dict(state_dict)
+    return mobile_sam
+
+
 sam_model_registry = {
     "default": build_sam_vit_h,
     "vit_h": build_sam_vit_h,
     "vit_l": build_sam_vit_l,
     "vit_b": build_sam_vit_b,
+    "vit_t": build_sam_vit_t,
 }


@@ -105,3 +155,5 @@ def _build_sam(
         state_dict = torch.load(f)
     sam.load_state_dict(state_dict)
     return sam
+
+
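
To exercise the new registry entry in isolation, a short sketch (the local checkpoint path is hypothetical, and the import path assumes the package re-exports sam_model_registry and SamAutomaticMaskGenerator, as from_pretrained implies):

import numpy as np
from controlnet_aux.segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# "vit_t" now resolves to build_sam_vit_t, which assembles Sam around a
# TinyViT image encoder instead of ImageEncoderViT.
mobile_sam = sam_model_registry["vit_t"](checkpoint="mobile_sam.pt")  # hypothetical local path
mask_generator = SamAutomaticMaskGenerator(mobile_sam)

image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8)  # stand-in HxWx3 uint8 image
masks = mask_generator.generate(image)  # list of mask dicts, one per detected region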

src/controlnet_aux/segment_anything/modeling/__init__.py (+1)

@@ -9,3 +9,4 @@
 from .mask_decoder import MaskDecoder
 from .prompt_encoder import PromptEncoder
 from .transformer import TwoWayTransformer
+from .tiny_vit_sam import TinyViT

src/controlnet_aux/segment_anything/modeling/sam.py (+3 -2)

@@ -8,8 +8,9 @@
 from torch import nn
 from torch.nn import functional as F

-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Union

+from .tiny_vit_sam import TinyViT
 from .image_encoder import ImageEncoderViT
 from .mask_decoder import MaskDecoder
 from .prompt_encoder import PromptEncoder
@@ -21,7 +22,7 @@ class Sam(nn.Module):

     def __init__(
         self,
-        image_encoder: ImageEncoderViT,
+        image_encoder: Union[ImageEncoderViT, TinyViT],
         prompt_encoder: PromptEncoder,
         mask_decoder: MaskDecoder,
         pixel_mean: List[float] = [123.675, 116.28, 103.53],
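
The widened Union annotation reflects that Sam only needs an encoder that maps a preprocessed image batch to dense embeddings. A structural sketch of that contract (this Protocol is illustrative only, not part of the codebase; the shapes are assumed from prompt_embed_dim=256 and image_embedding_size=64 in build_sam.py):

from typing import Protocol
import torch

class SamImageEncoder(Protocol):
    # Both ImageEncoderViT and TinyViT (as configured in build_sam_vit_t)
    # are assumed to satisfy this shape contract:
    #   (B, 3, 1024, 1024) float images -> (B, 256, 64, 64) embeddings.
    img_size: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...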
