forked from Janspiry/Palette-Image-to-Image-Diffusion-Models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__init__.py
executable file
·29 lines (20 loc) · 1.01 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from core.praser import init_obj
def create_model(**cfg_model):
    """Instantiate the model class selected by ``opt['model']['which_model']``.

    Keyword Args:
        opt: Parsed option mapping; must contain ``opt['model']['which_model']``,
            the model spec handed to ``init_obj``.
        logger: Logger forwarded to ``init_obj``.
        **other: Every keyword passed here (including ``opt`` and ``logger``)
            is merged over the spec's ``args`` mapping. Presumably ``init_obj``
            forwards ``args`` to the model constructor — confirm in
            ``core.praser``.

    Returns:
        Whatever ``init_obj`` returns for ``init_type='Model'``.

    Raises:
        KeyError: if ``opt``, ``logger`` or ``opt['model']['which_model']``
            is missing.
    """
    import copy

    opt = cfg_model['opt']
    logger = cfg_model['logger']
    which_model = opt['model']['which_model']

    # Merge into shallow copies instead of calling ``update`` on the config.
    # An in-place update would store ``opt`` inside itself (a self-reference)
    # and leak networks, loaders, etc. into the shared option mapping.
    # ``copy.copy`` keeps the original mapping type.
    # ``or {}`` tolerates a spec with no ``args`` key, or with ``args: None``.
    model_args = copy.copy(which_model.get('args') or {})
    model_args.update(cfg_model)
    model_opt = copy.copy(which_model)
    model_opt['args'] = model_args

    return init_obj(model_opt, logger, default_file_name='models.model', init_type='Model')
def define_network(logger, opt, network_opt):
    """Build a network from *network_opt*; initialise its weights in training.

    Args:
        logger: Logger passed to ``init_obj`` and used for the init message.
        opt: Parsed option mapping; only ``opt['phase']`` is read here.
        network_opt: Network spec handed to ``init_obj``. In the training
            phase its ``args`` mapping is read for an optional ``init_type``.
            That value is used only in the log message.

    Returns:
        The network object produced by ``init_obj``.
    """
    network = init_obj(network_opt, logger, init_type='Network', default_file_name='models.network')
    # Weight initialisation only happens in the training phase.
    if opt['phase'] != 'train':
        return network

    method = network_opt['args'].get('init_type', 'default')
    class_name = network.__class__.__name__
    logger.info('Network [{}] weights initialize using [{:s}] method.'.format(class_name, method))
    network.init_weights()
    return network
def define_loss(logger, loss_opt):
    """Build a loss object from *loss_opt* via ``init_obj``.

    ``models.loss`` is passed as the default module name and ``'Loss'`` as
    the init type; *logger* is forwarded unchanged.
    """
    loss = init_obj(loss_opt, logger, init_type='Loss', default_file_name='models.loss')
    return loss
def define_metric(logger, metric_opt):
    """Build a metric object from *metric_opt* via ``init_obj``.

    ``models.metric`` is passed as the default module name and ``'Metric'``
    as the init type; *logger* is forwarded unchanged.
    """
    metric = init_obj(metric_opt, logger, init_type='Metric', default_file_name='models.metric')
    return metric