Skip to content

Commit

Permalink
Merge branch 'dev' into mask-debug
Browse files Browse the repository at this point in the history
  • Loading branch information
hellock committed Oct 5, 2018
2 parents 98b20b9 + 5266dea commit bb6ef3b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
12 changes: 12 additions & 0 deletions mmdet/models/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ class BaseDetector(nn.Module):
def __init__(self):
super(BaseDetector, self).__init__()

@property
def with_neck(self):
return hasattr(self, 'neck') and self.neck is not None

@property
def with_bbox(self):
return hasattr(self, 'bbox_head') and self.bbox_head is not None

@property
def with_mask(self):
return hasattr(self, 'mask_head') and self.mask_head is not None

@abstractmethod
def extract_feat(self, imgs):
pass
Expand Down
4 changes: 2 additions & 2 deletions mmdet/models/detectors/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ def __init__(self,
def init_weights(self, pretrained=None):
super(RPN, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
if self.neck is not None:
if self.with_neck:
self.neck.init_weights()
self.rpn_head.init_weights()

def extract_feat(self, img):
x = self.backbone(img)
if self.neck is not None:
if self.with_neck:
x = self.neck(x)
return x

Expand Down
14 changes: 7 additions & 7 deletions mmdet/models/detectors/two_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,19 @@ def __init__(self,
self.backbone = builder.build_backbone(backbone)

if neck is not None:
self.with_neck = True
self.neck = builder.build_neck(neck)
else:
raise NotImplementedError

self.with_rpn = True if rpn_head is not None else False
if self.with_rpn:
if rpn_head is not None:
self.rpn_head = builder.build_rpn_head(rpn_head)

self.with_bbox = True if bbox_head is not None else False
if self.with_bbox:
if bbox_head is not None:
self.bbox_roi_extractor = builder.build_roi_extractor(
bbox_roi_extractor)
self.bbox_head = builder.build_bbox_head(bbox_head)

self.with_mask = True if mask_head is not None else False
if self.with_mask:
if mask_head is not None:
self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor)
self.mask_head = builder.build_mask_head(mask_head)
Expand All @@ -51,6 +47,10 @@ def __init__(self,

self.init_weights(pretrained=pretrained)

@property
def with_rpn(self):
return hasattr(self, 'rpn_head') and self.rpn_head is not None

def init_weights(self, pretrained=None):
super(TwoStageDetector, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
Expand Down

0 comments on commit bb6ef3b

Please sign in to comment.