Skip to content

Commit

Permalink
Resolve conflicts.
Browse files Browse the repository at this point in the history
  • Loading branch information
shijianjian committed Apr 13, 2020
2 parents eab8ead + d8481a5 commit a90b355
Show file tree
Hide file tree
Showing 16 changed files with 1,312 additions and 209 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Custom
tmp

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,28 @@ Mostly, we can save 4 times GPU memories by reducing the stride from the first b

Below is the README from the original repo:

_IMPORTANT NOTE_: In the latest update, I switched hosting providers for the pretrained models, as the previous models were becoming extremely expensive to host. This _will_ break old versions of the library. I apologize, but I cannot afford to keep serving the models on the old provider. Everything should work properly if you update the library:
```
pip install --upgrade efficientnet-pytorch
```

### Update (January 23, 2020)

This update adds a new category of pre-trained model based on adversarial training, called _advprop_. It is important to note that the preprocessing required for the advprop pretrained models is slightly different from normal ImageNet preprocessing. As a result, by default, advprop models are not used. To load a model with advprop, use:
```
model = EfficientNet.from_pretrained("efficientnet-b0", advprop=True)
```
There is also a new, large `efficientnet-b8` pretrained model that is only available in advprop form. When using these models, replace ImageNet preprocessing code as follows:
```
if advprop: # for models using advprop pretrained weights
normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
else:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
```
This update also addresses multiple other issues ([#115](https://github.com/lukemelas/EfficientNet-PyTorch/issues/115), [#128](https://github.com/lukemelas/EfficientNet-PyTorch/issues/128)).

### Update (October 15, 2019)

This update allows you to choose whether to use a memory-efficient Swish activation. The memory-efficient version is chosen by default, but it cannot be used when exporting using PyTorch JIT. For this purpose, we have also included a standard (export-friendly) swish activation function. To switch to the export-friendly version, simply call `model.set_swish(memory_efficient=False)` after loading your desired model. This update addresses issues [#88](https://github.com/lukemelas/EfficientNet-PyTorch/pull/88) and [#89](https://github.com/lukemelas/EfficientNet-PyTorch/pull/89).
Expand Down
2 changes: 1 addition & 1 deletion efficientnet_pytorch_3d/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.5.1"
__version__ = "0.6.3"
from .model import EfficientNet3D
from .utils import (
GlobalParams,
Expand Down
9 changes: 3 additions & 6 deletions efficientnet_pytorch_3d/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
get_same_padding_conv3d,
get_model_params,
efficientnet_params,
load_pretrained_weights,
Swish,
MemoryEfficientSwish,
)
Expand Down Expand Up @@ -211,10 +210,8 @@ def get_image_size(cls, model_name):
return res

@classmethod
def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False):
""" Validates model name. None that pretrained weights are only available for
the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """
num_models = 4 if also_need_pretrained_weights else 8
valid_models = ['efficientnet-b'+str(i) for i in range(num_models)]
def _check_model_name_is_valid(cls, model_name):
""" Validates model name. """
valid_models = ['efficientnet-b'+str(i) for i in range(9)]
if model_name not in valid_models:
raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
27 changes: 2 additions & 25 deletions efficientnet_pytorch_3d/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def efficientnet_params(model_name):
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
}
return params_dict[model_name]

Expand Down Expand Up @@ -293,28 +295,3 @@ def get_model_params(model_name, override_params):
# ValueError will be raised here if override_params has fields not included in global_params.
global_params = global_params._replace(**override_params)
return blocks_args, global_params


url_map = {
'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth',
'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth',
'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth',
'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth',
'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth',
'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth',
'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth',
'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth',
}


def load_pretrained_weights(model, model_name, load_fc=True):
""" Loads pretrained weights, and downloads if loading for the first time. """
state_dict = model_zoo.load_url(url_map[model_name])
if load_fc:
model.load_state_dict(state_dict)
else:
state_dict.pop('_fc.weight')
state_dict.pop('_fc.bias')
res = model.load_state_dict(state_dict, strict=False)
assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
print('Loaded pretrained weights for {}'.format(model_name))
10 changes: 6 additions & 4 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import torch
from torchsummary import summary

device = 'cpu'

model = EfficientNet3D.from_name("efficientnet-b0", override_params={'num_classes': 2}, in_channels=1)

summary(model, input_size=(1, 200, 1024, 200))
summary(model, input_size=(1, 224, 224, 224))

model = model.to("cuda:3")
inputs = torch.randn((1, 1, 200, 1024, 200)).to("cuda:3")
labels = torch.tensor([0]).to("cuda:3")
model = model.to(device)
inputs = torch.randn((1, 1, 224, 224, 224)).to(device)
labels = torch.tensor([0]).to(device)
# test forward
num_classes = 2

Expand Down
Loading

0 comments on commit a90b355

Please sign in to comment.