diff --git a/lcnn/models/line_vectorizer.py b/lcnn/models/line_vectorizer.py index 2a77f26..c4890d5 100644 --- a/lcnn/models/line_vectorizer.py +++ b/lcnn/models/line_vectorizer.py @@ -45,7 +45,7 @@ def __init__(self, backbone): def forward(self, input_dict): result = self.backbone(input_dict) - h = result["heatmaps"] + h = result["preds"] x = self.fc1(result["feature"]) n_batch, n_channel, row, col = x.shape @@ -134,16 +134,16 @@ def sum_batch(x): jcs[i][j] = jcs[i][j][ None, torch.arange(M.n_out_junc) % len(jcs[i][j]) ] - result["heatmaps"]["lines"] = torch.cat(lines) - result["heatmaps"]["score"] = torch.cat(score) - result["heatmaps"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)]) + result["preds"]["lines"] = torch.cat(lines) + result["preds"]["score"] = torch.cat(score) + result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)]) if len(jcs[i]) > 1: - result["heatmaps"]["junts"] = torch.cat( + result["preds"]["junts"] = torch.cat( [jcs[i][1] for i in range(n_batch)] ) else: - if "heatmaps" in result: - del result["heatmaps"] + if "preds" in result: + del result["preds"] return result def sample_lines(self, meta, jmap, joff, do_evaluation): diff --git a/lcnn/models/multitask_learner.py b/lcnn/models/multitask_learner.py index d76577f..e6730b5 100644 --- a/lcnn/models/multitask_learner.py +++ b/lcnn/models/multitask_learner.py @@ -61,7 +61,7 @@ def forward(self, input_dict, output_feature=True): lmap = output[offset[0] : offset[1]].squeeze(0) joff = output[offset[1] : offset[2]].reshape(n_jtyp, 2, batch, row, col) if stack == 0: - result["heatmaps"] = { + result["preds"] = { "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1], "lmap": lmap.sigmoid(), "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5, diff --git a/lcnn/trainer.py b/lcnn/trainer.py index 872528b..d05d017 100644 --- a/lcnn/trainer.py +++ b/lcnn/trainer.py @@ -122,7 +122,7 @@ def validate(self): total_loss += self._loss(result) - H = result["heatmaps"] + H = result["preds"] for i in range(H["jmap"].shape[0]): index = batch_idx * self.batch_size + i np.savez( diff --git a/process.py b/process.py index 279c81e..32c90f3 100755 --- a/process.py +++ b/process.py @@ -99,7 +99,7 @@ def main(): "target": recursive_to(target, device), "do_evaluation": True, } - H = model(input_dict)["heatmaps"] + H = model(input_dict)["preds"] for i in range(M.batch_size): index = batch_idx * M.batch_size + i np.savez(