import os
import traceback
import sys
import argparse
import time
import gc
import warnings
import torch
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.utils.metrics import IoU
from segmentation_models_pytorch.utils.losses import DiceLoss
import matplotlib.pyplot as plt
from pathlib import Path
from torch.utils import data
from image_module.dataset_water import WaterDataset_RGB
ROOT_DIR = './'
# time_str = time.strftime("%Y-%m-%d %H-%M-%S")
DEFAULT_CHKPT_DIR = os.path.join(ROOT_DIR, 'output', 'img_seg_checkpoint')
# # Device
# DEVICE = torch.device('cpu')
# if torch.cuda.is_available():
# DEVICE = torch.device('cuda')
# Input size must be a multiple of 32 as the image will be subsampled 5 times
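# For example, 416 is valid because 416 / 2**5 = 13 exactly, whereas a size such as 400
# is not divisible by 32 and would misalign the encoder/decoder feature maps.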
def train(args):
"""
Executes train script given arguments
:param args: Training parameters
:return:
"""
try:
torch.cuda.empty_cache()
except:
print("Error clearing cache.")
print(traceback.format_exc())
dataset_path = args.dataset_path
input_shape = args.input_shape
batch_size = args.batch_size
init_lr = args.init_lr
epochs = args.epochs
out_path = args.out_path
encoder_name = args.encoder
# train_dir = os.path.join(dataset_path, 'train')
train_dir = os.path.join(dataset_path, '')
# val_dir = os.path.join(dataset_path, 'val')
val_dir = os.path.join(dataset_path, '')
# Input size must be a multiple of 32 as the image will be subsampled 5 times
train_dataset = WaterDataset_RGB(
mode='train_offline',
dataset_path=train_dir,
input_size=(416, 416)
)
    val_dataset = WaterDataset_RGB(
        mode='train_offline',
        dataset_path=val_dir,
        input_size=(input_shape, input_shape)
    )
    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4
    )
    val_loader = data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=4
    )
    linknet_model = smp.Linknet(
        encoder_name=encoder_name,
        encoder_depth=5,
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
        activation='sigmoid'
    )
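    # Note: encoder_depth=5 gives five downsampling stages, which is where the
    # multiple-of-32 input constraint comes from; classes=1 with a sigmoid activation
    # yields a single-channel water/background probability mask.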
    # Train the LinkNet model with the given backbone
    try:
        train_model(
            linknet_model,
            init_lr=init_lr,
            num_epochs=epochs,
            out_path=out_path,
            train_loader=train_loader,
            val_loader=val_loader,
            encoder_name=encoder_name
        )
    except Exception:
        print(traceback.format_exc())
    try:
        linknet_model = None
        gc.collect()
    except Exception:
        print(traceback.format_exc())
def train_model(model, init_lr, num_epochs, out_path, train_loader, val_loader, encoder_name):
"""
Trains a single image given model and further arguments
:param model: Model from SMP library
:param init_lr: Initial learning rate
:param num_epochs: Number of epochs to train
:param out_path: Folder to output checkpoints and model
:param train_loader: Dataloader for train dataset
:param val_loader: Dataloader for validation dataset
:return:
"""
plots_dir = os.path.join(out_path, 'graphs')
checkpoints_dir = os.path.join(out_path, 'checkpoints')
models_dir = os.path.join(out_path, 'model')
if not os.path.exists(plots_dir):
os.makedirs(plots_dir)
if not os.path.exists(checkpoints_dir):
os.makedirs(checkpoints_dir)
if not os.path.exists(models_dir):
os.makedirs(models_dir)
loss = DiceLoss()
metrics = [
IoU(threshold=0.5),
]
optimizer = torch.optim.Adam([
dict(params=model.parameters(), lr=init_lr),
])
# Create training epoch
train_epoch = smp.utils.train.TrainEpoch(
model,
loss=loss,
metrics=metrics,
optimizer=optimizer,
device=device,
verbose=True
)
# Create validation epoch
valid_epoch = smp.utils.train.ValidEpoch(
model,
loss=loss,
metrics=metrics,
device=device,
verbose=True
)
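    # TrainEpoch/ValidEpoch iterate over a DataLoader, evaluate the loss and metrics on
    # each batch, and return a dict of averaged values keyed by the loss/metric names
    # ('dice_loss' and 'iou_score' here), which the logging below relies on.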
    max_score = 0
    train_iou_score_ls = []
    train_dice_loss_ls = []
    val_iou_score_ls = []
    val_dice_loss_ls = []
    # Go through each epoch
    for epoch in range(0, num_epochs):
        title = 'Epoch: {}'.format(epoch)
        print('\nEpoch: {}'.format(epoch))
        # Epoch logs
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(val_loader)
        # Checkpoint to resume training
        checkpoint = {
            'epoch': epoch,
            'weights': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': loss.state_dict()
        }
        # Get IoU score
        score = float(valid_logs['iou_score'])
        checkpoint_savepth = os.path.join(
            checkpoints_dir, 'epoch_' + str(epoch).zfill(3) + '_score' + str(score) + '.pth')
        torch.save(checkpoint, checkpoint_savepth)
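        # A saved checkpoint could later be used to resume training, roughly like this
        # (sketch only; it assumes the same model and optimizer are rebuilt first):
        #   ckpt = torch.load(checkpoint_savepth, map_location=device)
        #   model.load_state_dict(ckpt['weights'])
        #   optimizer.load_state_dict(ckpt['optimizer'])
        #   start_epoch = ckpt['epoch'] + 1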
        # Check score on valid dataset
        if score > max_score:
            max_score = score
            model_savepth = os.path.join(
                models_dir,
                'linknet_' + encoder_name + '_epoch_' + str(epoch).zfill(3) + '_score' + str(score) + '.pth')
            torch.save(model, model_savepth)
            print('New best model detected.')
        # Adjust learning rate halfway through training.
        if epoch == int(num_epochs / 2):
            optimizer.param_groups[0]['lr'] = 1e-5
            print('Decrease decoder learning rate to 1e-5!')
        train_iou_score_ls.append(train_logs['iou_score'])
        train_dice_loss_ls.append(train_logs['dice_loss'])
        val_iou_score_ls.append(valid_logs['iou_score'])
        val_dice_loss_ls.append(valid_logs['dice_loss'])
        plot_train_filepth = os.path.join(plots_dir, 'epoch_' + str(epoch).zfill(3) + '_train.png')
        plot_val_filepth = os.path.join(plots_dir, 'epoch_' + str(epoch).zfill(3) + '_val.png')
        plt.plot(train_iou_score_ls, label='train iou_score')
        plt.plot(train_dice_loss_ls, label='train dice_loss')
        plt.legend(loc="upper left")
        plt.title(title)
        plt.savefig(plot_train_filepth)
        plt.close()
        plt.plot(val_iou_score_ls, label='val iou_score')
        plt.plot(val_dice_loss_ls, label='val dice_loss')
        plt.legend(loc="upper left")
        plt.title(title)
        plt.savefig(plot_val_filepth)
        plt.close()
"""
python train_segmodel.py --dataset_path
"""
if __name__ == '__main__':
    # Hyperparameters
    parser = argparse.ArgumentParser(description='PyTorch WaterNet Model Training')
    # Required: Path to the dataset folder.
    parser.add_argument('--dataset-path',
                        type=str,
                        metavar='PATH',
                        help='Path to the dataset. Expects the format shown in the header comments.')
    # Required: Encoder (backbone) name.
    parser.add_argument('--encoder',
                        type=str,
                        metavar='NAME',
                        help='Encoder name, as used by the segmentation_models_pytorch library')
    # Optional: Image input size that the model should be designed to accept. In LinkNet, the image is
    # subsampled 5 times, so the size must be a multiple of 32.
    parser.add_argument('--input-shape',
                        default=416,
                        type=int,
                        help='(OPTIONAL) Input size for the model. Single integer, must be a multiple of 32.')
    # Optional: Batch size for mini-batch gradient descent. Defaults to 4; depends on your GPU and input shape.
    parser.add_argument('--batch-size',
                        default=4,
                        type=int,
                        help='(OPTIONAL) Batch size for mini-batch gradient descent.')
    # Optional: Initial learning rate. The learning rate is set to 1e-5 halfway through training.
    parser.add_argument('--init-lr',
                        default=1e-4,
                        type=float,
                        help='(OPTIONAL) Initial learning rate.')
    # Optional: Number of epochs for training.
    parser.add_argument('--epochs',
                        default=300,
                        type=int,
                        help='(OPTIONAL) Number of epochs for training')
    # Optional: Folder where checkpoints will be saved. Defaults to a new checkpoint folder in output.
    parser.add_argument('--out-path',
                        default=DEFAULT_CHKPT_DIR,
                        type=str,
                        metavar='PATH',
                        help='(OPTIONAL) Path to the output folder, defaults to output/img_seg_checkpoint under the project root')
    _args = parser.parse_args()
print("== System Details ==")
print(torch.cuda.is_available())
print(torch.cuda.current_device())
print(torch.cuda.device(0))
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
print("== System Details ==")
print()
device = torch.device('cpu')
if torch.cuda.is_available():
device = torch.device('cuda')
train(_args)
print("Done.")