-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
110 lines (83 loc) · 3.93 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# coding: utf-8
import chainer.training.extensions as ex
from chainer.iterators import SerialIterator
from chainer.optimizer_hooks import WeightDecay
from chainer.optimizers import Adam
from chainer.training import PRIORITY_READER, Trainer
from datasets import get_dataset, get_unlabel_dataset
from discriminator import FCN
from functions import adam_lr_poly
from generator import DilatedFCN, ResNetDeepLab, UNet
from options import get_options
from updater import AdvSemiSeg_Updater
def train(opt):
    """Run adversarial semi-supervised segmentation training.

    Builds the generator/discriminator pair, their Adam optimizers, the
    labeled and unlabeled data iterators, and a Chainer Trainer with
    logging/plotting/snapshot extensions, then starts the training loop.

    Args:
        opt: parsed options object (see options.get_options()) carrying
            learning rates, betas, batch size, epoch counts, intervals,
            and output directory.
    """
    # Chainer device convention: -1 selects CPU, 0 selects the first GPU.
    if opt.use_cpu:
        device = -1
        print('[Message] use CPU')
    else:
        device = 0
        print('[Message] use GPU0')

    # Labeled (annotated) and unlabeled datasets for the semi-supervised setup.
    labeled_data = get_dataset(opt)
    unlabeled_data = get_unlabel_dataset(opt)
    print('[Message] loaded options')

    labeled_iter = SerialIterator(labeled_data, opt.batch_size, shuffle=True)
    print('[Message] converted to iterator (train)')
    unlabeled_iter = SerialIterator(unlabeled_data, opt.batch_size, shuffle=True)
    print('[Message] converted to iterator (semi)')

    # Generator network; DilatedFCN / UNet are interchangeable alternatives.
    generator = ResNetDeepLab(opt)
    # generator = DilatedFCN(opt)
    # generator = UNet(opt)
    if device != -1:
        generator.to_gpu(device)  # move parameters to GPU
    gen_optimizer = Adam(alpha=opt.g_lr, beta1=opt.g_beta1, beta2=opt.g_beta2)
    gen_optimizer.setup(generator)
    # Optional L2 regularization on the generator only.
    if opt.g_weight_decay > 0:
        gen_optimizer.add_hook(WeightDecay(opt.g_weight_decay))
    print('[Message] setuped Generator')

    # Fully-convolutional discriminator for the adversarial signal.
    discriminator = FCN(opt)
    if device != -1:
        discriminator.to_gpu(device)  # move parameters to GPU
    dis_optimizer = Adam(alpha=opt.d_lr, beta1=opt.d_beta1, beta2=opt.d_beta2)
    dis_optimizer.setup(discriminator)
    print('[Message] setuped Discriminator')

    # Custom updater couples both iterators and optimizers; the 'semi'
    # iterator feeds the unlabeled branch.
    updater = AdvSemiSeg_Updater(opt,
                                 iterator={'main': labeled_iter, 'semi': unlabeled_iter},
                                 optimizer={'gen': gen_optimizer, 'dis': dis_optimizer},
                                 device=device)
    print('[Message] initialized Updater')

    trainer = Trainer(updater, (opt.max_epoch, 'epoch'), out=opt.out_dir)
    print('[Message] initialized Trainer')

    # --- reporting / plotting extensions ---
    trainer.extend(ex.LogReport(log_name=None, trigger=(1, 'iteration')))
    trainer.extend(ex.ProgressBar((opt.max_epoch, 'epoch'), update_interval=1))
    trainer.extend(ex.PlotReport(['gen/adv_loss', 'dis/adv_loss', 'gen/semi_adv_loss'],
                                 x_key='iteration', file_name='adversarial_loss.png', trigger=(100, 'iteration')))
    # test
    trainer.extend(ex.PlotReport(['gen/adv_loss'],
                                 x_key='iteration', file_name='adv_gen_loss.png', trigger=(100, 'iteration')))
    trainer.extend(ex.PlotReport(['gen/ce_loss'],
                                 x_key='iteration', file_name='cross_entropy_loss.png', trigger=(100, 'iteration')))
    trainer.extend(ex.PlotReport(['gen/semi_st_loss'],
                                 x_key='iteration', file_name='self_teach_loss.png', trigger=(100, 'iteration')))
    trainer.extend(ex.PlotReport(['gen/loss', 'dis/loss', 'gen/semi_loss'],
                                 x_key='iteration', file_name='loss.png', trigger=(100, 'iteration')))
    trainer.extend(ex.PlotReport(['gen/loss', 'dis/loss', 'gen/semi_loss'],
                                 x_key='epoch', file_name='loss_details.png', trigger=(5, 'epoch')))
    trainer.extend(ex.PlotReport(['gen/semi_loss'],
                                 x_key='epoch', file_name='semi_loss.png', trigger=(1, 'epoch')))

    # --- periodic snapshots of both networks ---
    trainer.extend(ex.snapshot_object(generator, 'gen_snapshot_epoch-{.updater.epoch}.npz'),
                   trigger=(opt.snap_interval_epoch, 'epoch'))
    trainer.extend(ex.snapshot_object(discriminator, 'dis_snapshot_epoch-{.updater.epoch}.npz'),
                   trigger=(opt.snap_interval_epoch, 'epoch'))

    # Custom hooks: dump sample images, switch on semi-supervised learning
    # after a warm-up, and poly-decay the Adam learning rates.
    trainer.extend(lambda *args: updater.save_img(),
                   trigger=(opt.img_interval_iteration, 'iteration'), priority=PRIORITY_READER)
    trainer.extend(lambda *args: updater.ignition_semi_learning(),
                   trigger=(opt.semi_ignit_iteration, 'iteration'), priority=PRIORITY_READER)
    trainer.extend(lambda *args: adam_lr_poly(opt, trainer), trigger=(100, 'iteration'))
    print('[Message] initialized extension')

    print('[Message] start training ...')
    trainer.run()  # enter the training loop
if __name__ == '__main__':
    # Parse command-line options and launch training.
    train(get_options())