-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathupdater.py
123 lines (93 loc) · 3.55 KB
/
updater.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# coding: utf-8
import os
from chainer.training import StandardUpdater
from chainer import Variable
import numpy as np
from PIL import Image
from loss import dis_loss, gen_loss, gen_semi_loss
from chainer.backends import cuda
from functions import onehot2label
class AdvSemiSeg_Updater(StandardUpdater):
    """Updater that alternates generator and discriminator optimisation.

    Each call to :meth:`update_core` performs a generator step and a
    discriminator step on a batch from the ``'main'`` iterator.  Once
    :meth:`ignition_semi_learning` has been called, it additionally performs
    a second generator step on a batch drawn from the ``'semi'`` iterator.

    The first sample of the most recent batches is cached on the CPU so that
    :meth:`save_img` can write a preview image.
    """

    def __init__(self, opt, *args, **kwargs):
        """Store the options object and initialise the preview state.

        Args:
            opt: Options object.  This class reads ``opt.out_dir`` and passes
                ``opt`` through to the loss functions.
            *args: Forwarded to ``StandardUpdater.__init__``.
            **kwargs: Forwarded to ``StandardUpdater.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.opt = opt
        # Counter used as the file name of the next preview image.
        self.num_saved_img = 0
        # Toggled on by ignition_semi_learning().
        self.learn_from_unlabel = False

        # Cached [input, target, prediction] arrays (first sample of the
        # batch) from the latest update; None until update_core has run.
        self.img4save = None
        self.semi_img4save = None

    def update_core(self):
        """Run one training iteration.

        The discriminator outputs for the real and generated maps are
        computed before the generator is updated, and the same outputs are
        reused for the discriminator loss afterwards.
        """
        g_opt = self.get_optimizer('gen')
        d_opt = self.get_optimizer('dis')

        # predict
        x, real_g = self.real_batch('main')
        fake_g = g_opt.target(x)

        # Keep the first sample of the batch for save_img().
        self.img4save = [
            cuda.to_cpu(x.array[0]),
            cuda.to_cpu(real_g.array[0]),
            cuda.to_cpu(fake_g.array[0]),
        ]

        real_d = d_opt.target(real_g)
        fake_d = d_opt.target(fake_g)

        # generator loss
        g_loss = gen_loss(self.opt, fake_d, real_g, fake_g,
                          observer=g_opt.target)
        g_opt.target.cleargrads()
        g_loss.backward()
        g_opt.update()

        # discriminator loss
        # Cut the graph behind the generator output so the discriminator's
        # backward pass stops at fake_g instead of re-entering the generator.
        x.unchain_backward()
        fake_g.unchain_backward()

        d_loss = dis_loss(self.opt, real_d, fake_d, observer=d_opt.target)
        d_opt.target.cleargrads()
        d_loss.backward()
        d_opt.update()

        if self.learn_from_unlabel:
            # predict
            unlabel_x, _ = self.real_batch('semi')
            unlabel_g = g_opt.target(unlabel_x)
            unlabel_d = d_opt.target(unlabel_g)

            # There is no target for this batch; the None entry is rendered
            # as a blank panel by save_img().
            self.semi_img4save = [
                cuda.to_cpu(unlabel_x.array[0]),
                None,
                cuda.to_cpu(unlabel_g.array[0]),
            ]

            # semi-supervised loss
            semi_loss = gen_semi_loss(self.opt, unlabel_d, unlabel_g,
                                      observer=g_opt.target)
            g_opt.target.cleargrads()
            semi_loss.backward()
            g_opt.update()

    def real_batch(self, iter_key='main'):
        """Fetch the next batch from an iterator as float32 Variables.

        Args:
            iter_key: Name of the iterator to draw from.

        Returns:
            ``(x, t)`` when the converted batch is a tuple or list, otherwise
            ``(x, None)``.
        """
        batch = self.get_iterator(iter_key).next()
        batch = self.converter(batch, self.device)

        if isinstance(batch, (tuple, list)):
            x, t = batch
            # 16bit -> 32bit (do not use tensor cores)
            return (Variable(x.astype('float32')),
                    Variable(t.astype('float32')))

        return Variable(batch.astype('float32')), None

    @staticmethod
    def _render_row(sections):
        """Turn one cached [input, target, prediction] list into an image row.

        A new list is built instead of writing back into ``sections``, so the
        cached arrays stay in their original form and the row can be rendered
        any number of times.
        """
        panels = []
        for i, sect in enumerate(sections):
            if sect is None:
                # Blank placeholder with the shape of the (already converted)
                # first panel; the first entry itself is never None here.
                panels.append(np.zeros_like(panels[0]))
                continue

            if i != 0:
                # Every entry except the input goes through onehot2label.
                sect = onehot2label(sect)

            # Move axis 0 to the end for image output.
            panels.append(np.transpose(sect, (1, 2, 0)))

        return np.concatenate(panels, axis=1)

    def save_img(self):
        """Write a preview PNG of the latest cached samples.

        The file is saved as ``<opt.out_dir>/out_img/<n>.png`` where ``n`` is
        a counter incremented after each save.  Nothing is written if
        :meth:`update_core` has not run yet.
        """
        if self.img4save is None:
            return

        rows = [self.img4save]
        # semi_img4save is still None between ignition_semi_learning() and
        # the next update_core(); skip the row in that window.
        if self.learn_from_unlabel and self.semi_img4save is not None:
            rows.append(self.semi_img4save)

        # Panels side by side within a row, rows stacked top to bottom.
        tile_img = np.concatenate(
            [self._render_row(row) for row in rows], axis=0)

        out = Image.fromarray(np.uint8(tile_img * 255))

        out_dir_name = os.path.join(self.opt.out_dir, 'out_img')
        os.makedirs(out_dir_name, exist_ok=True)

        out.save(os.path.join(out_dir_name, str(self.num_saved_img) + '.png'))
        self.num_saved_img += 1

    def ignition_semi_learning(self):
        """Enable the semi-supervised step for subsequent updates."""
        self.learn_from_unlabel = True