-
Notifications
You must be signed in to change notification settings - Fork 343
/
options.py
executable file
·157 lines (127 loc) · 8.85 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import argparse
import os
class BaseOptions:
    """Command-line options for training, evaluation and single-image reconstruction.

    Typical use is ``opt = BaseOptions().parse()``. The ``argparse`` parser is
    built on first use (by ``parse``/``gather_options``/``print_options``) and
    cached on ``self.parser`` so the instance can be parsed repeatedly.
    """

    def __init__(self):
        # Set to True by initialize() once all arguments have been registered.
        self.initialized = False
        # Cached parser; created lazily by _get_parser().
        self.parser = None

    def initialize(self, parser):
        """Register every option on ``parser`` and return it.

        Argument groups (Data, Experiment, Training, Testing, Sampling, Model,
        aug) only affect ``--help`` layout; the options under the "for train",
        "for eval", "path" and "single image reconstruction" comments are added
        to the top-level parser directly.

        Args:
            parser: an ``argparse.ArgumentParser`` to populate.

        Returns:
            The same parser with all arguments added.
        """
        # Datasets related
        g_data = parser.add_argument_group('Data')
        g_data.add_argument('--dataroot', type=str, default='./data',
                            help='path to images (data folder)')
        g_data.add_argument('--loadSize', type=int, default=512, help='load size of input image')

        # Experiment related
        g_exp = parser.add_argument_group('Experiment')
        g_exp.add_argument('--name', type=str, default='example',
                           help='name of the experiment. It decides where to store samples and models')
        g_exp.add_argument('--debug', action='store_true', help='debug mode or not')
        g_exp.add_argument('--num_views', type=int, default=1, help='How many views to use for multiview network.')
        g_exp.add_argument('--random_multiview', action='store_true', help='Select random multiview combination.')

        # Training related
        g_train = parser.add_argument_group('Training')
        g_train.add_argument('--gpu_id', type=int, default=0, help='gpu id for cuda')
        g_train.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2, -1 for CPU mode')
        g_train.add_argument('--num_threads', default=1, type=int, help='# threads for loading data')
        g_train.add_argument('--serial_batches', action='store_true',
                             help='if true, takes images in order to make batches, otherwise takes them randomly')
        g_train.add_argument('--pin_memory', action='store_true', help='pin_memory')
        g_train.add_argument('--batch_size', type=int, default=2, help='input batch size')
        g_train.add_argument('--learning_rate', type=float, default=1e-3, help='adam learning rate')
        g_train.add_argument('--learning_rateC', type=float, default=1e-3, help='adam learning rate')
        g_train.add_argument('--num_epoch', type=int, default=100, help='num epoch to train')
        g_train.add_argument('--freq_plot', type=int, default=10, help='frequency of the error plot')
        g_train.add_argument('--freq_save', type=int, default=50, help='frequency of the save_checkpoints')
        g_train.add_argument('--freq_save_ply', type=int, default=100, help='frequency of the save ply')
        g_train.add_argument('--no_gen_mesh', action='store_true')
        g_train.add_argument('--no_num_eval', action='store_true')
        g_train.add_argument('--resume_epoch', type=int, default=-1, help='epoch resuming the training')
        g_train.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')

        # Testing related
        g_test = parser.add_argument_group('Testing')
        g_test.add_argument('--resolution', type=int, default=256, help='# of grid in mesh reconstruction')
        g_test.add_argument('--test_folder_path', type=str, default=None, help='the folder of test image')

        # Sampling related
        g_sample = parser.add_argument_group('Sampling')
        g_sample.add_argument('--sigma', type=float, default=5.0, help='perturbation standard deviation for positions')
        g_sample.add_argument('--num_sample_inout', type=int, default=5000, help='# of sampling points')
        g_sample.add_argument('--num_sample_color', type=int, default=0, help='# of sampling points')
        g_sample.add_argument('--z_size', type=float, default=200.0, help='z normalization factor')

        # Model related
        g_model = parser.add_argument_group('Model')
        # General
        g_model.add_argument('--norm', type=str, default='group',
                             help='instance normalization or batch normalization or group normalization')
        g_model.add_argument('--norm_color', type=str, default='instance',
                             help='instance normalization or batch normalization or group normalization')
        # hg filter specify
        g_model.add_argument('--num_stack', type=int, default=4, help='# of hourglass')
        g_model.add_argument('--num_hourglass', type=int, default=2, help='# of stacked layer of hourglass')
        g_model.add_argument('--skip_hourglass', action='store_true', help='skip connection in hourglass')
        g_model.add_argument('--hg_down', type=str, default='ave_pool', help='ave pool || conv64 || conv128')
        # Int default (previously the string '256'): argparse converts string
        # defaults while parsing, but get_default() returns the raw string, so
        # print_options() reported this option as changed on every run.
        g_model.add_argument('--hourglass_dim', type=int, default=256, help='256 | 512')
        # Classification General
        g_model.add_argument('--mlp_dim', nargs='+', default=[257, 1024, 512, 256, 128, 1], type=int,
                             help='# of dimensions of mlp')
        g_model.add_argument('--mlp_dim_color', nargs='+', default=[513, 1024, 512, 256, 128, 3],
                             type=int, help='# of dimensions of color mlp')
        g_model.add_argument('--use_tanh', action='store_true',
                             help='using tanh after last conv of image_filter network')

        # for train
        parser.add_argument('--random_flip', action='store_true', help='if random flip')
        parser.add_argument('--random_trans', action='store_true', help='if random translation')
        parser.add_argument('--random_scale', action='store_true', help='if random scale')
        parser.add_argument('--no_residual', action='store_true', help='no skip connection in mlp')
        parser.add_argument('--schedule', type=int, nargs='+', default=[60, 80],
                            help='Decrease learning rate at these epochs.')
        parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
        parser.add_argument('--color_loss_type', type=str, default='l1', help='mse | l1')

        # for eval
        parser.add_argument('--val_test_error', action='store_true', help='validate errors of test data')
        parser.add_argument('--val_train_error', action='store_true', help='validate errors of train data')
        parser.add_argument('--gen_test_mesh', action='store_true', help='generate test mesh')
        parser.add_argument('--gen_train_mesh', action='store_true', help='generate train mesh')
        parser.add_argument('--all_mesh', action='store_true', help='generate meshes from all hourglass output')
        parser.add_argument('--num_gen_mesh_test', type=int, default=1,
                            help='how many meshes to generate during testing')

        # path
        parser.add_argument('--checkpoints_path', type=str, default='./checkpoints', help='path to save checkpoints')
        parser.add_argument('--load_netG_checkpoint_path', type=str, default=None,
                            help='path to load netG checkpoint from')
        parser.add_argument('--load_netC_checkpoint_path', type=str, default=None,
                            help='path to load netC checkpoint from')
        parser.add_argument('--results_path', type=str, default='./results', help='path to save results ply')
        parser.add_argument('--load_checkpoint_path', type=str, help='path to load checkpoint from')
        parser.add_argument('--single', type=str, default='', help='single data for training')

        # for single image reconstruction
        parser.add_argument('--mask_path', type=str, help='path for input mask')
        parser.add_argument('--img_path', type=str, help='path for input image')

        # aug
        group_aug = parser.add_argument_group('aug')
        group_aug.add_argument('--aug_alstd', type=float, default=0.0, help='augmentation pca lighting alpha std')
        group_aug.add_argument('--aug_bri', type=float, default=0.0, help='augmentation brightness')
        group_aug.add_argument('--aug_con', type=float, default=0.0, help='augmentation contrast')
        group_aug.add_argument('--aug_sat', type=float, default=0.0, help='augmentation saturation')
        group_aug.add_argument('--aug_hue', type=float, default=0.0, help='augmentation hue')
        group_aug.add_argument('--aug_blur', type=float, default=0.0, help='augmentation blur')

        # special tasks
        self.initialized = True
        return parser

    def _get_parser(self):
        """Return the cached parser, building and populating it on first use."""
        parser = getattr(self, 'parser', None)
        # Previously gather_options() only bound a local parser when not yet
        # initialized, so a second call raised UnboundLocalError; reuse the cache.
        if not self.initialized or parser is None:
            parser = argparse.ArgumentParser(
                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)
            self.parser = parser
        return parser

    def gather_options(self, args=None):
        """Parse command-line options.

        Args:
            args: optional list of argument strings; ``None`` (the default)
                makes argparse read ``sys.argv[1:]`` as before.

        Returns:
            The ``argparse.Namespace`` produced by ``parse_args``.
        """
        return self._get_parser().parse_args(args)

    def print_options(self, opt):
        """Print every option in ``opt`` sorted by name.

        Options whose value differs from the parser default are annotated with
        ``[default: ...]``. Builds the parser first if it does not exist yet.
        """
        parser = self._get_parser()
        lines = ['----------------- Options ---------------']
        for k, v in sorted(vars(opt).items()):
            default = parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}'.format(str(k), str(v), comment))
        lines.append('----------------- End -------------------')
        print('\n'.join(lines))

    def parse(self, args=None):
        """Parse and return the options (``args`` defaults to ``sys.argv[1:]``)."""
        # Call without arguments when none are given so subclasses that override
        # gather_options() with the old no-argument signature keep working.
        if args is None:
            return self.gather_options()
        return self.gather_options(args)