Skip to content

Commit a00c8ea

Browse files
author
Matt Macy
committed
make training work again
1 parent 850f5b5 commit a00c8ea

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

train.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@
4343
#ct_targets = nodule_masks
4444

4545

46-
nodule_masks = "luna16_nodule_masks"
47-
lung_masks = "luna16_seg_lungs"
46+
nodule_masks = "normalized_brightened_CT_2_5"
47+
lung_masks = "inferred_seg_lungs_2_5"
4848
ct_images = "luna16_ct_normalized"
49-
ct_targets = lung_masks
49+
ct_targets = nodule_masks
5050
target_split = [2, 2, 2]
5151

5252
def weights_init(m):
@@ -99,6 +99,7 @@ def inference(args, loader, model, transforms):
9999

100100
def noop(x):
101101
return x
102+
102103
def main():
103104
parser = argparse.ArgumentParser()
104105
parser.add_argument('--batchSz', type=int, default=10)
@@ -194,7 +195,7 @@ def main():
194195
else:
195196
masks = None
196197

197-
if args.inference is not None:
198+
if args.inference != '':
198199
if not args.resume:
199200
print("args.resume must be set to do inference")
200201
exit(1)
@@ -221,7 +222,7 @@ def main():
221222
mode="test", transform=testTransform, seed=args.seed, masks=masks, split=target_split),
222223
batch_size=batch_size, shuffle=False, **kwargs)
223224

224-
target_mean = trainSet.target_weight()
225+
target_mean = trainSet.target_mean()
225226
bg_weight = target_mean / (1. + target_mean)
226227
fg_weight = 1. - bg_weight
227228
print(bg_weight)

0 commit comments

Comments
 (0)