|
8 | 8 |
|
9 | 9 | from functools import partial
|
10 | 10 |
|
11 |
| -from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer |
| 11 | +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT |
12 | 12 |
|
13 | 13 |
|
14 | 14 | def build_sam_vit_h(checkpoint=None):
|
@@ -44,11 +44,61 @@ def build_sam_vit_b(checkpoint=None):
|
44 | 44 | )
|
45 | 45 |
|
46 | 46 |
|
def build_sam_vit_t(checkpoint=None):
    """Build a MobileSAM model: a TinyViT image encoder paired with the
    standard SAM prompt encoder and mask decoder.

    Args:
        checkpoint: Optional path to a saved state dict to load into the
            model. When ``None``, the model keeps its freshly initialized
            weights.

    Returns:
        A ``Sam`` instance, switched to eval mode.
    """
    embed_dim = 256
    img_size = 1024
    patch_size = 16
    # Spatial side length of the image embedding grid (1024 / 16 = 64).
    grid_size = img_size // patch_size

    image_encoder = TinyViT(
        img_size=1024,
        in_chans=3,
        num_classes=1000,
        embed_dims=[64, 128, 160, 320],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
        mlp_ratio=4.,
        drop_rate=0.,
        drop_path_rate=0.0,
        use_checkpoint=False,
        mbconv_expand_ratio=4.0,
        local_conv_size=3,
        layer_lr_decay=0.8,
    )
    prompt_encoder = PromptEncoder(
        embed_dim=embed_dim,
        image_embedding_size=(grid_size, grid_size),
        input_image_size=(img_size, img_size),
        mask_in_chans=16,
    )
    mask_decoder = MaskDecoder(
        num_multimask_outputs=3,
        transformer=TwoWayTransformer(
            depth=2,
            embedding_dim=embed_dim,
            mlp_dim=2048,
            num_heads=8,
        ),
        transformer_dim=embed_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )

    mobile_sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=prompt_encoder,
        mask_decoder=mask_decoder,
        # Standard ImageNet normalization constants used by all SAM variants.
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )

    mobile_sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        mobile_sam.load_state_dict(state_dict)
    return mobile_sam
| 95 | + |
# Registry mapping a model-type key to its builder; "default" aliases ViT-H.
sam_model_registry = dict(
    default=build_sam_vit_h,
    vit_h=build_sam_vit_h,
    vit_l=build_sam_vit_l,
    vit_b=build_sam_vit_b,
    vit_t=build_sam_vit_t,
)
|
53 | 103 |
|
54 | 104 |
|
@@ -105,3 +155,5 @@ def _build_sam(
|
105 | 155 | state_dict = torch.load(f)
|
106 | 156 | sam.load_state_dict(state_dict)
|
107 | 157 | return sam
|
| 158 | + |
| 159 | + |
0 commit comments