Skip to content

Commit

Permalink
Added support for non-RGB images
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemelas committed Oct 12, 2019
1 parent b6a1be9 commit 3676f4e
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions efficientnet_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,21 @@ def from_name(cls, model_name, override_params=None):
blocks_args, global_params = get_model_params(model_name, override_params)
return cls(blocks_args, global_params)

@classmethod
def from_pretrained(cls, model_name, num_classes=1000, in_channels = 3):
model = cls.from_name(model_name, override_params={'num_classes': num_classes})
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
if in_channels != 3:
Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size)
out_channels = round_filters(32, model._global_params)
model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
return model

@classmethod
def from_pretrained(cls, model_name, num_classes=1000):
model = cls.from_name(model_name, override_params={'num_classes': num_classes})
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))

return model

@classmethod
Expand Down

0 comments on commit 3676f4e

Please sign in to comment.