Commit ccbad24 (initial commit, 0 parents). Showing 13 changed files with 680 additions and 0 deletions.
.gitignore (new file, 108 additions)

```gitignore
# Pycharm
.idea

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
.static_storage/
.media/
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
```
README.md (new file, 12 additions)

```markdown
# nasnet.pytorch

## Work In Progress

PyTorch implementation of [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012). Currently crashes on PyTorch 0.2.0; should work with PyTorch master or PyTorch 0.3.0.

## TODO

* Clean up code
* Refactor the quick and dirty PowersignCD
* Pretrain nets on ImageNet
* Write a better path dropout
```
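The `PowersignCD` named in the TODO is presumably the PowerSign update rule with cosine decay from Bello et al., "Neural Optimizer Search with Reinforcement Learning" (2017). Its source is not visible in this capture, so the following is only a minimal sketch of what such an optimizer might look like, written against a current PyTorch API; the constructor signature (`steps`, `lr`, `momentum`) is an assumption inferred from the call site in the training script below, and `PowerSignCD` here is a hypothetical class, not the repo's implementation.

```python
import math

import torch
from torch.optim import Optimizer


class PowerSignCD(Optimizer):
    """Sketch of PowerSign with cosine decay:
        w <- w - lr * exp(cd(t) * sign(g) * sign(m)) * g
    where m is an exponential moving average of the gradient and
    cd(t) = 0.5 * (1 + cos(pi * t / steps)) anneals from 1 to 0."""

    def __init__(self, params, steps, lr=0.1, momentum=0.9):
        super().__init__(params, dict(lr=lr, momentum=momentum, steps=steps))

    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if not state:
                    state["t"] = 0
                    state["m"] = torch.zeros_like(g)
                state["t"] += 1
                m = state["m"].mul_(group["momentum"]).add_(g, alpha=1 - group["momentum"])
                cd = 0.5 * (1 + math.cos(math.pi * state["t"] / group["steps"]))
                # Amplify the step when gradient and momentum agree in sign,
                # damp it when they disagree; as cd -> 0 this decays to plain SGD.
                p.data.add_(torch.exp(cd * g.sign() * m.sign()) * g, alpha=-group["lr"])
        return loss
```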
Training script (new file, 121 additions)

```python
import argparse
import os

import torch
import torch.nn.functional as F
import torch.utils.data
from torch.autograd import Variable
from torchvision import datasets, transforms
from tqdm import tqdm

from nasnet import nasnetmobile, nasnetlarge, PowersignCD


class Trainer(object):
    cuda = torch.cuda.is_available()

    def __init__(self, model, optimizer, loss_f, save_dir=None, save_freq=5):
        self.model = model
        if self.cuda:
            model.cuda()
        self.optimizer = optimizer
        self.loss_f = loss_f
        self.save_dir = save_dir
        self.save_freq = save_freq

    def _loop(self, data_loader, is_train=True):
        loop_loss = []
        correct = []
        for data, target in tqdm(data_loader):
            if self.cuda:
                data, target = data.cuda(), target.cuda()
            # volatile=True disables graph construction at eval time (PyTorch 0.3-era API)
            data, target = Variable(data, volatile=not is_train), Variable(target, volatile=not is_train)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss_f(output, target)
            loop_loss.append(loss.data[0] / len(data_loader))
            correct.append((output.data.max(1)[1] == target.data).sum() / len(data_loader.dataset))
            if is_train:
                loss.backward()
                self.optimizer.step()
        mode = "train" if is_train else "test"
        print(f">>>[{mode}] loss: {sum(loop_loss):.2f}/accuracy: {sum(correct):.2%}")
        return loop_loss, correct

    def train(self, data_loader):
        self.model.train()
        loss, correct = self._loop(data_loader)

    def test(self, data_loader):
        self.model.eval()
        loss, correct = self._loop(data_loader, is_train=False)

    def loop(self, epochs, train_data, test_data, scheduler=None):
        for ep in range(1, epochs + 1):
            if scheduler is not None:
                scheduler.step()
            print(f"epochs: {ep}")
            self.train(train_data)
            self.test(test_data)
            if ep % self.save_freq == 0:  # save every save_freq epochs
                self.save(ep)

    def save(self, epoch, **kwargs):
        if self.save_dir:
            name = f"weight-{epoch}-" + "-".join([f"{k}_{v}" for k, v in kwargs.items()]) + ".pkl"
            torch.save({"weight": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict()},
                       os.path.join(self.save_dir, name))


def main(model_type, batch_size, data_root, n_epochs):
    if model_type == 'mobile':
        input_size = 224
        model = nasnetmobile
    elif model_type == 'large':
        input_size = 331
        model = nasnetlarge
    else:
        input_size = 299
        model = nasnetlarge

    transform_train = transforms.Compose([
        transforms.RandomSizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    transform_test = transforms.Compose([
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    traindir = os.path.join(data_root, 'train')
    valdir = os.path.join(data_root, 'val')
    train = datasets.ImageFolder(traindir, transform_train)
    val = datasets.ImageFolder(valdir, transform_test)
    train_loader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=True, num_workers=8)
    test_loader = torch.utils.data.DataLoader(
        val, batch_size=batch_size, shuffle=True, num_workers=8)
    net = model(num_classes=1000)
    optimizer = PowersignCD(params=net.parameters(), steps=len(train) / batch_size * n_epochs, lr=0.6, momentum=0.9)
    trainer = Trainer(net, optimizer, F.cross_entropy, save_dir=".")
    trainer.loop(n_epochs, train_loader, test_loader)


if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument("root", help="imagenet data root")
    p.add_argument("--type", default="large", help="'mobile' or 'large'")
    p.add_argument("--batch_size", default=8, type=int)
    p.add_argument("--n_epochs", default=10, type=int)
    args = p.parse_args()
    main(args.type, args.batch_size, args.root, args.n_epochs)
```
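The script's filename is not visible in this capture. Assuming it were saved as `main.py` (a hypothetical name), a run against an ImageNet-style directory would look like `python main.py /data/imagenet --type mobile --batch_size 8 --n_epochs 10`, where `/data/imagenet` contains the `train/` and `val/` subfolders that `ImageFolder` expects, each holding one directory per class.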
nasnet/__init__.py (new file, 2 additions)

```python
from .nasnet import *
from .optimizer import PowersignCD
```
Five binary files in this commit are not shown.
Drop-path module (new file, 35 additions)

```python
import torch
import torch.cuda
import torch.nn as nn
from random import random
from torch.autograd import Variable


# Currently there is a risk of dropping all paths in a cell...
# We should create a version that takes all paths into account to make sure
# one stays alive, but then keep_prob is meaningless and we would have to
# compute/keep track of the conditional probability.
class DropPath(nn.Module):
    """Wraps a module and, during training, replaces its entire output with
    zeros with probability 1 - keep_prob (scheduled drop path)."""

    def __init__(self, module, keep_prob=0.9):
        super(DropPath, self).__init__()
        self.module = module
        self.keep_prob = keep_prob
        self.shape = None
        self.dtype = torch.FloatTensor

    def forward(self, *input):
        if self.training:
            # If we don't know the output shape yet, run the module once and
            # cache it, so that dropped passes can skip the computation.
            # Note: the cached shape assumes a fixed batch size.
            if self.shape is None:
                temp = self.module(*input)
                self.shape = temp.size()
                if temp.data.is_cuda:
                    self.dtype = torch.cuda.FloatTensor
                del temp
            if random() > self.keep_prob:
                # Path dropped: emit zeros of the cached output shape.
                return Variable(self.dtype(self.shape).zero_())
            return self.module(*input) / self.keep_prob  # inverted scaling
        else:
            return self.module(*input)
```
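The comment at the top of that module asks for a variant that considers all of a cell's paths jointly so that at least one always stays alive. A minimal sketch of such a variant follows. `JointDropPath` is a hypothetical name, not part of this commit; it is written against a current PyTorch API rather than the 0.3-era one above, and, as the original comment anticipates, forcing one path back on makes the effective keep probability slightly higher than `keep_prob`, so the inverted scaling here is only approximate.

```python
import torch
import torch.nn as nn


class JointDropPath(nn.Module):
    """Hypothetical joint drop-path: sample a keep-mask over all candidate
    paths at once and revive one at random if every path was dropped."""

    def __init__(self, paths, keep_prob=0.9):
        super().__init__()
        self.paths = nn.ModuleList(paths)
        self.keep_prob = keep_prob

    def forward(self, *inputs):
        outputs = [path(*inputs) for path in self.paths]
        if not self.training:
            return sum(outputs)
        # One independent keep/drop decision per path.
        mask = torch.rand(len(outputs)) < self.keep_prob
        if not mask.any():
            # Every path was dropped: force one back on so the cell never goes silent.
            mask[torch.randint(len(outputs), (1,)).item()] = True
        # Approximate inverted scaling; the revival step biases the true keep
        # probability above keep_prob, as noted in the lead-in.
        kept = [out for out, keep in zip(outputs, mask.tolist()) if keep]
        return sum(kept) / self.keep_prob
```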