Skip to content

Commit

Permalink
Split training, testing, validation
Browse files Browse the repository at this point in the history
  • Loading branch information
zhou13 committed Feb 4, 2020
1 parent ce9895e commit c57814e
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 14 deletions.
2 changes: 1 addition & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def main():
"jmap": torch.zeros([1, 1, 128, 128]).to(device),
"joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
},
"do_evaluation": True,
"mode": "testing",
}
H = model(input_dict)["preds"]

Expand Down
21 changes: 12 additions & 9 deletions lcnn/models/line_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def forward(self, input_dict):
xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
for i, meta in enumerate(input_dict["meta"]):
p, label, feat, jc = self.sample_lines(
meta, h["jmap"][i], h["joff"][i], input_dict["do_evaluation"]
meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
)
# print("p.shape:", p.shape)
ys.append(label)
if not input_dict["do_evaluation"] and self.do_static_sampling:
if input_dict["mode"] == "training" and self.do_static_sampling:
p = torch.cat([p, meta["lpre"]])
feat = torch.cat([feat, meta["lpre_feat"]])
ys.append(meta["lpre_label"])
Expand Down Expand Up @@ -95,7 +95,7 @@ def forward(self, input_dict):
x = torch.cat([x, f], 1)
x = self.fc2(x).flatten()

if input_dict["do_evaluation"]:
if input_dict["mode"] != "training":
p = torch.cat(ps)
s = torch.sigmoid(x)
b = s > 0.5
Expand Down Expand Up @@ -128,7 +128,8 @@ def forward(self, input_dict):
result["preds"]["junts"] = torch.cat(
[jcs[i][1] for i in range(n_batch)]
)
else:

if input_dict["mode"] != "testing":
y = torch.cat(ys)
loss = self.loss(x, y)
lpos_mask, lneg_mask = y, 1 - y
Expand All @@ -142,11 +143,13 @@ def sum_batch(x):
lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
result["losses"][0]["lpos"] = lpos * M.loss_weight["lpos"]
result["losses"][0]["lneg"] = lneg * M.loss_weight["lneg"]

if input_dict["mode"] == "training":
del result["preds"]

return result

def sample_lines(self, meta, jmap, joff, do_evaluation):
def sample_lines(self, meta, jmap, joff, mode):
with torch.no_grad():
junc = meta["junc"] # [N, 2]
jtyp = meta["jtyp"] # [N]
Expand All @@ -158,7 +161,7 @@ def sample_lines(self, meta, jmap, joff, do_evaluation):
joff = joff.reshape(n_type, 2, -1)
max_K = M.n_dyn_junc // n_type
N = len(junc)
if do_evaluation:
if mode != "training":
K = min(int((jmap > M.eval_junc_thres).float().sum().item()), max_K)
else:
K = min(int(N * 2 + 2), max_K)
Expand Down Expand Up @@ -193,9 +196,7 @@ def sample_lines(self, meta, jmap, joff, do_evaluation):
up, vp = match[u], match[v]
label = Lpos[up, vp]

if do_evaluation:
c = (u < v).flatten()
else:
if mode == "training":
c = torch.zeros_like(label, dtype=torch.bool)

# sample positive lines
Expand All @@ -217,6 +218,8 @@ def sample_lines(self, meta, jmap, joff, do_evaluation):
# sample other (unmatched) lines
cdx = torch.randint(len(c), (M.n_dyn_othr,), device=device)
c[cdx] = 1
else:
c = (u < v).flatten()

# sample lines
u, v, label = u[c], v[c], label[c]
Expand Down
2 changes: 1 addition & 1 deletion lcnn/models/multitask_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def forward(self, input_dict):
"lmap": lmap.sigmoid(),
"joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
}
if input_dict["do_evaluation"]:
if input_dict["mode"] == "testing":
return result

L = OrderedDict()
Expand Down
4 changes: 2 additions & 2 deletions lcnn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def validate(self):
"image": recursive_to(image, self.device),
"meta": recursive_to(meta, self.device),
"target": recursive_to(target, self.device),
"do_evaluation": True,
"mode": "validation",
}
result = self.model(input_dict)

Expand Down Expand Up @@ -173,7 +173,7 @@ def train_epoch(self):
"image": recursive_to(image, self.device),
"meta": recursive_to(meta, self.device),
"target": recursive_to(target, self.device),
"do_evaluation": False,
"mode": "training",
}
result = self.model(input_dict)

Expand Down
2 changes: 1 addition & 1 deletion process.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def main():
"image": recursive_to(image, device),
"meta": recursive_to(meta, device),
"target": recursive_to(target, device),
"do_evaluation": True,
"mode": "validation",
}
H = model(input_dict)["preds"]
for i in range(M.batch_size):
Expand Down

0 comments on commit c57814e

Please sign in to comment.