forked from oneThousand1000/HairMapper
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c2e439e
commit 1b25ea4
Showing
217 changed files
with
12,292 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import os | ||
import cv2 | ||
import random | ||
import numpy as np | ||
import torch | ||
import argparse | ||
from shutil import copyfile | ||
from .src.config import Config | ||
from .src.classifier import Classifier | ||
|
||
def get_model(mode=None,attribuite='hair'): | ||
config = load_config(mode) | ||
|
||
# cuda visble devices | ||
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(e) for e in config.GPU) | ||
|
||
|
||
# init device | ||
if torch.cuda.is_available(): | ||
config.DEVICE = torch.device("cuda") | ||
torch.backends.cudnn.benchmark = True # cudnn auto-tuner | ||
else: | ||
config.DEVICE = torch.device("cpu") | ||
|
||
|
||
|
||
cv2.setNumThreads(0) | ||
|
||
|
||
# initialize random seed | ||
torch.manual_seed(config.SEED) | ||
torch.cuda.manual_seed_all(config.SEED) | ||
np.random.seed(config.SEED) | ||
random.seed(config.SEED) | ||
|
||
|
||
|
||
# build the model and initialize | ||
if attribuite=='hair': | ||
model = Classifier(config,'hair_classification') | ||
elif attribuite=='gender': | ||
model = Classifier(config,'gender_classification') | ||
elif attribuite=='smile': | ||
model = Classifier(config,'smile_classification') | ||
model.load() | ||
return model | ||
|
||
|
||
|
||
def load_config(mode=None): | ||
r"""loads model config | ||
Args: | ||
mode (int): 1: train, 2: test, 3: eval, reads from config file if not specified | ||
""" | ||
|
||
|
||
|
||
|
||
|
||
config_path = os.path.join(os.path.dirname(__file__),'./config.yml') | ||
|
||
# create checkpoints path if does't exist | ||
if not os.path.exists('./checkpoints'): | ||
os.makedirs('./checkpoints') | ||
|
||
# copy config template if does't exist | ||
if not os.path.exists(config_path): | ||
copyfile('./config.yml.example', config_path) | ||
|
||
# load config file | ||
config = Config(config_path) | ||
|
||
|
||
|
||
return config | ||
|
||
def check_hair(img,model): | ||
output= model.process(img) | ||
return output[0][1]>0.09 | ||
|
||
|
||
def check_gender(img,model): | ||
output= model.process(img) | ||
return output[0][1]>0.15 | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
MODE: 1 # 1: train, 2: test, 3: eval | ||
MODEL: 1 # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model | ||
EDGE: 1 # 1: canny, 2: external | ||
NMS: 1 # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny | ||
SEED: 10 # random seed | ||
GPU: [0] # list of gpu ids | ||
DEBUG: 0 # turns on debugging mode | ||
VERBOSE: 0 # turns on verbose mode in the output console | ||
|
||
LR: 0.0001 # learning rate | ||
D2G_LR: 0.1 # discriminator/generator learning rate ratio | ||
BETA1: 0.0 # adam optimizer beta1 | ||
BETA2: 0.9 # adam optimizer beta2 | ||
BATCH_SIZE: 4 # input batch size for training | ||
INPUT_SIZE: 256 # input image size for training 0 for original size | ||
SIGMA: 2 # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge) | ||
MAX_ITERS: 2e6 # maximum number of iterations to train the model | ||
|
||
EDGE_THRESHOLD: 0.5 # edge detection threshold | ||
L1_LOSS_WEIGHT: 1 # l1 loss weight | ||
FM_LOSS_WEIGHT: 10 # feature-matching loss weight | ||
STYLE_LOSS_WEIGHT: 250 # style loss weight | ||
CONTENT_LOSS_WEIGHT: 0.1 # perceptual loss weight | ||
INPAINT_ADV_LOSS_WEIGHT: 0.1 # adversarial loss weight | ||
|
||
GAN_LOSS: nsgan # nsgan | lsgan | hinge | ||
GAN_POOL_SIZE: 0 # fake images pool size | ||
|
||
SAVE_INTERVAL: 1000 # how many iterations to wait before saving model (0: never) | ||
SAMPLE_INTERVAL: 400 # how many iterations to wait before sampling (0: never) | ||
SAMPLE_SIZE: 4 # number of images to sample | ||
EVAL_INTERVAL: 0 # how many iterations to wait before model evaluation (0: never) | ||
LOG_INTERVAL: 10 # how many iterations to wait before logging training status (0: never) | ||
|
||
num_cpus_per_job: 4 | ||
num_gpus_per_job: 1 | ||
log_dir: './checkpoints/' | ||
model_dir: './checkpoints/model/stock2.model' | ||
model_load_dir: './checkpoints/model/stock2.model' | ||
val: True | ||
|
||
ori_train: './gender_data/train/origin_img.flist' | ||
ori_val: './gender_data/val/origin_img.flist' | ||
|
||
use_mask: True |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
import os | ||
import torch | ||
from torch.utils.data import DataLoader | ||
from .dataset import Dataset | ||
from .models import ClassificationModel | ||
from .utils import Progbar, create_dir | ||
|
||
class Classifier(): | ||
def __init__(self, config,name): | ||
self.config = config | ||
|
||
model_name = name | ||
|
||
self.debug = False | ||
self.model_name = model_name | ||
self.classification_model = ClassificationModel(model_name,'classification_model',config).to(config.DEVICE) | ||
|
||
# test mode | ||
if self.config.MODE == 2: | ||
self.test_dataset =Dataset(config,config.ori_val, augment=False, training=True) | ||
else: | ||
self.train_dataset = Dataset(config, config.ori_train, augment=True, training=True) | ||
self.val_dataset = Dataset(config, config.ori_val, augment=False, training=True) | ||
self.sampler=Dataset(config, config.ori_val,augment=False, training=True) | ||
self.results_path = os.path.join(config.PATH, 'results') | ||
|
||
|
||
if config.DEBUG is not None and config.DEBUG != 0: | ||
self.debug = True | ||
create_dir(os.path.join(config.PATH,model_name)) | ||
self.log_file = os.path.join(os.path.join(config.PATH,model_name), 'log_' + model_name + '.dat') | ||
|
||
def load(self): | ||
self.classification_model.load() | ||
|
||
def save(self): | ||
self.classification_model.save() | ||
|
||
def train(self): | ||
train_loader = DataLoader( | ||
dataset=self.train_dataset, | ||
batch_size=self.config.BATCH_SIZE, | ||
num_workers=4, | ||
drop_last=True, | ||
shuffle=True | ||
) | ||
|
||
epoch = 0 | ||
keep_training = True | ||
max_iteration = int(float((self.config.MAX_ITERS))) | ||
total = len(self.train_dataset) | ||
if total == 0: | ||
print('No training data was provided! Check \'TRAIN_FLIST\' value in the configuration file.') | ||
return | ||
|
||
while(keep_training): | ||
epoch += 1 | ||
print('\n\nTraining epoch: %d' % epoch) | ||
|
||
progbar = Progbar(total, width=20, stateful_metrics=['epoch', 'iter']) | ||
|
||
for items in train_loader: | ||
self.classification_model.train() | ||
|
||
images,labels= self.cuda(*items) | ||
|
||
|
||
outputs, loss, logs,precision = self.classification_model.process(images,labels) | ||
#print(outputs) | ||
|
||
|
||
# backward | ||
self.classification_model.backward(loss) | ||
iteration = self.classification_model.iteration | ||
|
||
|
||
|
||
|
||
|
||
if iteration >= max_iteration: | ||
keep_training = False | ||
break | ||
|
||
logs = [ | ||
("epoch", epoch), | ||
("iter", iteration), | ||
] + logs | ||
progbar.add(len(images), | ||
values=logs if self.config.VERBOSE else [x for x in logs if not x[0].startswith('l_')]) | ||
|
||
# log model at checkpoints | ||
if self.config.LOG_INTERVAL and iteration % self.config.LOG_INTERVAL == 0: | ||
self.log(logs) | ||
|
||
# sample model at checkpoints | ||
|
||
|
||
# evaluate model at checkpoints | ||
if self.config.EVAL_INTERVAL and iteration % self.config.EVAL_INTERVAL == 0: | ||
print('\nstart eval...\n') | ||
self.eval() | ||
|
||
# save model at checkpoints | ||
if self.config.SAVE_INTERVAL and iteration % self.config.SAVE_INTERVAL == 0: | ||
self.save() | ||
|
||
print('\nEnd training....') | ||
|
||
def eval(self): | ||
val_loader = DataLoader( | ||
dataset=self.val_dataset, | ||
batch_size=self.config.BATCH_SIZE, | ||
drop_last=True, | ||
shuffle=True | ||
) | ||
|
||
total = len(self.val_dataset) | ||
|
||
self.classification_model.eval() | ||
|
||
progbar = Progbar(total, width=20, stateful_metrics=['it']) | ||
iteration = 0 | ||
|
||
for items in val_loader: | ||
iteration += 1 | ||
images, labels = self.cuda(*items) | ||
|
||
outputs, loss, logs, precision = self.classification_model.process(images, labels) | ||
|
||
|
||
|
||
logs = [("it", iteration), ] + logs | ||
progbar.add(len(images), values=logs) | ||
|
||
|
||
|
||
def log(self, logs): | ||
with open(self.log_file, 'a') as f: | ||
f.write('%s\n' % ' '.join([str(item[1]) for item in logs])) | ||
|
||
def cuda(self, *args): | ||
return (item.to(self.config.DEVICE) for item in args) | ||
|
||
def postprocess(self, img): | ||
# [0, 1] => [0, 255] | ||
img = img * 255.0 | ||
img = img.permute(0, 2, 3, 1) | ||
return img.int() | ||
|
||
|
||
def process(self,img): | ||
|
||
with torch.no_grad(): | ||
img=self.sampler.generate_test_data(img) | ||
outputs = self.classification_model(img) | ||
outputs=outputs.cpu().numpy() | ||
#print(outputs) | ||
return outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import os | ||
import yaml | ||
|
||
class Config(dict): | ||
def __init__(self, config_path): | ||
with open(config_path, 'r') as f: | ||
self._yaml = f.read() | ||
self._dict = yaml.load(self._yaml) | ||
self._dict['PATH'] = os.path.dirname(config_path) | ||
|
||
def __getattr__(self, name): | ||
if self._dict.get(name) is not None: | ||
return self._dict[name] | ||
|
||
if DEFAULT_CONFIG.get(name) is not None: | ||
return DEFAULT_CONFIG[name] | ||
|
||
return None | ||
|
||
def print(self): | ||
print('Model configurations:') | ||
print('---------------------------------') | ||
print(self._yaml) | ||
print('') | ||
print('---------------------------------') | ||
print('') | ||
|
||
|
||
DEFAULT_CONFIG = { | ||
'MODE': 1, # 1: train, 2: test, 3: eval | ||
'NMS': 1, # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny | ||
'SEED': 10, # random seed | ||
'GPU': [0], # list of gpu ids | ||
'DEBUG': 0, # turns on debugging mode | ||
'VERBOSE': 0, # turns on verbose mode in the output console | ||
|
||
'LR': 0.0001, # learning rate | ||
'BETA1': 0.0, # adam optimizer beta1 | ||
'BETA2': 0.9, # adam optimizer beta2 | ||
'BATCH_SIZE': 4, # input batch size for training | ||
'INPUT_SIZE': 256, # input image size for training 0 for original size | ||
'SIGMA': 2, # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge) | ||
'MAX_ITERS': 2e6, # maximum number of iterations to train the model | ||
|
||
'EDGE_THRESHOLD': 0.5, # edge detection threshold | ||
'L1_LOSS_WEIGHT': 1, # l1 loss weight | ||
'FM_LOSS_WEIGHT': 10, # feature-matching loss weight | ||
'STYLE_LOSS_WEIGHT': 1, # style loss weight | ||
'CONTENT_LOSS_WEIGHT': 1, # perceptual loss weight | ||
'INPAINT_ADV_LOSS_WEIGHT': 0.01,# adversarial loss weight | ||
|
||
'GAN_LOSS': 'nsgan', # nsgan | lsgan | hinge | ||
'GAN_POOL_SIZE': 0, # fake images pool size | ||
|
||
'SAVE_INTERVAL': 1000, # how many iterations to wait before saving model (0: never) | ||
'SAMPLE_INTERVAL': 1000, # how many iterations to wait before sampling (0: never) | ||
'SAMPLE_SIZE': 12, # number of images to sample | ||
'EVAL_INTERVAL': 0, # how many iterations to wait before model evaluation (0: never) | ||
'LOG_INTERVAL': 10, # how many iterations to wait before logging training status (0: never) | ||
|
||
'num_cpus_per_job': 4, | ||
'num_gpus_per_job': 1, | ||
'log_dir': './checkpoints/', | ||
'model_dir': './checkpoints/model/stock2.model', | ||
'model_load_dir': './checkpoints/model/stock2.model', | ||
'val': True, | ||
|
||
'ori_train': './data/train/origin_img.flist', | ||
'ori_val': './data/val/origin_img.flist', | ||
'mask_train': './data/train/mask.flist', | ||
'mask_val': './data/val/mask.flist', | ||
'use_mask': True | ||
|
||
|
||
} |
Oops, something went wrong.