Merge pull request lukemelas#134 from lukemelas/advprop
Add advprop and b8 model
lukemelas authored Jan 24, 2020
2 parents dbb58b1 + 396b06b commit 50b0de9
Showing 15 changed files with 894 additions and 201 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
+# Custom
+tmp
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
23 changes: 23 additions & 0 deletions README.md
@@ -1,5 +1,28 @@
# EfficientNet PyTorch


_IMPORTANT NOTE_: In the latest update, I switched hosting providers for the pretrained models, as hosting them with the previous provider was becoming extremely expensive. This _will_ break old versions of the library. I apologize, but I cannot afford to keep serving the models on the old provider. Everything should work properly if you update the library:
```
pip install --upgrade efficientnet-pytorch
```

### Update (January 23, 2020)

This update adds a new category of pre-trained models based on adversarial training, called _advprop_. Note that the advprop models require slightly different preprocessing from standard ImageNet preprocessing, so they are not loaded by default. To load a model with advprop weights, use:
```
model = EfficientNet.from_pretrained("efficientnet-b0", advprop=True)
```
There is also a new, larger `efficientnet-b8` pretrained model, which is available only in advprop form. When using advprop models, replace the standard ImageNet normalization as follows:
```
if advprop:  # for models using advprop pretrained weights
    normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
else:
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
```
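For example, a minimal end-to-end inference sketch with the advprop preprocessing might look like this (a sketch, not part of the release: the 224 crop matches `efficientnet-b0`, and `img.jpg` is a placeholder path):
```
import torch
from PIL import Image
from torchvision import transforms
from efficientnet_pytorch import EfficientNet

advprop = True
model = EfficientNet.from_pretrained('efficientnet-b0', advprop=advprop)
model.eval()

if advprop:  # advprop weights expect inputs scaled to [-1, 1]
    normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
else:        # standard ImageNet statistics
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

img = preprocess(Image.open('img.jpg').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    logits = model(img)
print(logits.argmax(dim=1))
```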
This update also addresses multiple other issues ([#115](https://github.com/lukemelas/EfficientNet-PyTorch/issues/115), [#128](https://github.com/lukemelas/EfficientNet-PyTorch/issues/128)).

### Update (October 15, 2019)

This update allows you to choose whether to use a memory-efficient Swish activation. The memory-efficient version is chosen by default, but it cannot be used when exporting with PyTorch JIT. For this purpose, a standard (export-friendly) Swish activation is also included. To switch to the export-friendly version, simply call `model.set_swish(memory_efficient=False)` after loading your desired model. This update addresses issues [#88](https://github.com/lukemelas/EfficientNet-PyTorch/pull/88) and [#89](https://github.com/lukemelas/EfficientNet-PyTorch/pull/89).
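As a sketch (the export path and input size below are illustrative), switching to the standard Swish before tracing might look like:
```
import torch
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b0')
model.set_swish(memory_efficient=False)  # export-friendly Swish
model.eval()

# The memory-efficient Swish is a custom autograd Function and cannot be
# exported with JIT; the standard version can.
example = torch.randn(1, 3, 224, 224)
traced = torch.jit.trace(model, example)
traced.save('efficientnet-b0-traced.pt')  # hypothetical output path
```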
2 changes: 1 addition & 1 deletion efficientnet_pytorch/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.5.1"
__version__ = "0.6.0"
from .model import EfficientNet
from .utils import (
GlobalParams,
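After upgrading, a quick way to confirm the installed package matches this release (a trivial check, not part of the diff):
```
import efficientnet_pytorch
print(efficientnet_pytorch.__version__)  # expected: 0.6.0
```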
19 changes: 5 additions & 14 deletions efficientnet_pytorch/model.py
@@ -206,33 +206,24 @@ def from_name(cls, model_name, override_params=None):
         return cls(blocks_args, global_params)
 
     @classmethod
-    def from_pretrained(cls, model_name, num_classes=1000, in_channels = 3):
+    def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3):
         model = cls.from_name(model_name, override_params={'num_classes': num_classes})
-        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
+        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
         if in_channels != 3:
             Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size)
             out_channels = round_filters(32, model._global_params)
             model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
         return model
 
-    @classmethod
-    def from_pretrained(cls, model_name, num_classes=1000):
-        model = cls.from_name(model_name, override_params={'num_classes': num_classes})
-        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
-
-        return model
-
     @classmethod
     def get_image_size(cls, model_name):
         cls._check_model_name_is_valid(model_name)
         _, _, res, _ = efficientnet_params(model_name)
         return res
 
     @classmethod
-    def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False):
-        """ Validates model name. None that pretrained weights are only available for
-            the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """
-        num_models = 4 if also_need_pretrained_weights else 8
-        valid_models = ['efficientnet-b'+str(i) for i in range(num_models)]
+    def _check_model_name_is_valid(cls, model_name):
+        """ Validates model name. """
+        valid_models = ['efficientnet-b'+str(i) for i in range(9)]
         if model_name not in valid_models:
             raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
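A short usage sketch of the updated `from_pretrained` signature (the class count and channel count below are illustrative, not from the diff):
```
from efficientnet_pytorch import EfficientNet

# AdvProp weights for the new b8 model.
model = EfficientNet.from_pretrained('efficientnet-b8', advprop=True)

# Custom head and single-channel input: the final FC weights are not loaded
# (load_fc is False whenever num_classes != 1000) and the stem conv is rebuilt
# for in_channels != 3.
gray = EfficientNet.from_pretrained('efficientnet-b0', advprop=True,
                                    num_classes=10, in_channels=1)
```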
37 changes: 27 additions & 10 deletions efficientnet_pytorch/utils.py
Expand Up @@ -170,6 +170,8 @@ def efficientnet_params(model_name):
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
}
return params_dict[model_name]

@@ -293,20 +295,35 @@ def get_model_params(model_name, override_params):
     return blocks_args, global_params
 
 
-url_map = {
-    'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth',
-    'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth',
-    'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth',
-    'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth',
-    'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth',
-    'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth',
-    'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth',
-    'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth',
+url_map_aa = {
+    'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b0-355c32eb.pth',
+    'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b1-f1951068.pth',
+    'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b2-8bb594d6.pth',
+    'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b3-5fb5a3c3.pth',
+    'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b4-6ed6700e.pth',
+    'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b5-b6417697.pth',
+    'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b6-c76e70fd.pth',
+    'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b7-dcc49843.pth',
 }
 
 
-def load_pretrained_weights(model, model_name, load_fc=True):
+url_map_advprop = {
+    'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b0-b64d5a18.pth',
+    'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b1-0f3ce85a.pth',
+    'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b2-6e9d97e5.pth',
+    'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b3-cdd7c0f4.pth',
+    'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b4-44fb3a87.pth',
+    'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b5-86493f6b.pth',
+    'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b6-ac80338e.pth',
+    'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b7-4652b6dd.pth',
+    'efficientnet-b8': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b8-22a8fe65.pth',
+}
+
+
+def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
     """ Loads pretrained weights, and downloads if loading for the first time. """
+    # AutoAugment or Advprop (different preprocessing)
+    url_map = url_map_advprop if advprop else url_map_aa
     state_dict = model_zoo.load_url(url_map[model_name])
     if load_fc:
         model.load_state_dict(state_dict)
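For reference, a minimal sketch of calling this function directly with the new flag (ordinarily `from_pretrained` does this for you):
```
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import load_pretrained_weights

# Build the architecture without weights, then pull the AdvProp checkpoint
# from url_map_advprop rather than url_map_aa.
model = EfficientNet.from_name('efficientnet-b0')
load_pretrained_weights(model, 'efficientnet-b0', load_fc=True, advprop=True)
```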
11 changes: 8 additions & 3 deletions examples/imagenet/main.py
@@ -71,6 +71,8 @@
                     help='GPU id to use.')
 parser.add_argument('--image_size', default=224, type=int,
                     help='image size')
+parser.add_argument('--advprop', default=False, action='store_true',
+                    help='use advprop or not')
 parser.add_argument('--multiprocessing-distributed', action='store_true',
                     help='Use multi-processing distributed training to launch '
                          'N processes per node, which has N GPUs. This is the '
@@ -134,7 +136,7 @@ def main_worker(gpu, ngpus_per_node, args):
     # create model
     if 'efficientnet' in args.arch: # NEW
         if args.pretrained:
-            model = EfficientNet.from_pretrained(args.arch)
+            model = EfficientNet.from_pretrained(args.arch, advprop=args.advprop)
             print("=> using pre-trained model '{}'".format(args.arch))
         else:
             print("=> creating model '{}'".format(args.arch))
@@ -206,8 +208,11 @@ def main_worker(gpu, ngpus_per_node, args):
     # Data loading code
     traindir = os.path.join(args.data, 'train')
     valdir = os.path.join(args.data, 'val')
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
+    if args.advprop:
+        normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
+    else:
+        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                         std=[0.229, 0.224, 0.225])
 
     train_dataset = datasets.ImageFolder(
         traindir,
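Putting the new flag and the normalization together, a validation pipeline in the style of this example script might look like the following sketch (`data_dir`, the architecture, and the flag value are placeholders):
```
import os
from torchvision import datasets, transforms
from efficientnet_pytorch import EfficientNet

arch, advprop, data_dir = 'efficientnet-b0', True, '/path/to/imagenet'
image_size = EfficientNet.get_image_size(arch)  # 224 for b0, 672 for b8

if advprop:  # advprop checkpoints expect [-1, 1] inputs
    normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
else:
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

val_dataset = datasets.ImageFolder(
    os.path.join(data_dir, 'val'),
    transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ]))
```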
2 changes: 1 addition & 1 deletion setup.py
@@ -18,7 +18,7 @@
 EMAIL = '[email protected]'
 AUTHOR = 'Luke'
 REQUIRES_PYTHON = '>=3.5.0'
-VERSION = '0.5.1'
+VERSION = '0.6.0'
 
 # What packages are required for this module to be executed?
 REQUIRED = [
2 changes: 1 addition & 1 deletion tf_to_pytorch/convert_tf_to_pt/load_tf_weights.py
@@ -149,7 +149,7 @@ def load_and_save_temporary_tensorflow_model(model_name, model_ckpt, example_img
     parser = argparse.ArgumentParser(
         description='Convert TF model to PyTorch model and save for easier future loading')
     parser.add_argument('--model_name', type=str, default='efficientnet-b0',
-                        help='efficientnet-b{N}, where N is an integer 0 <= N <= 7')
+                        help='efficientnet-b{N}, where N is an integer 0 <= N <= 8')
     parser.add_argument('--tf_checkpoint', type=str, default='pretrained_tensorflow/efficientnet-b0/',
                         help='checkpoint file path')
     parser.add_argument('--output_file', type=str, default='pretrained_pytorch/efficientnet-b0.pth',
(Diffs for the remaining changed files are not shown here.)
