Skip to content

Commit

Permalink
Update demo.py
Browse files Browse the repository at this point in the history
  • Loading branch information
zhou13 committed Feb 1, 2020
1 parent fd4e2fc commit ce9895e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 51 deletions.
109 changes: 58 additions & 51 deletions demo.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#!/usr/bin/env python3
"""Process an image with the trained neural network
Usage:
demo.py [options] <yaml-config> <checkpoint> <image>
demo.py [options] <yaml-config> <checkpoint> <images>...
demo.py (-h | --help )
Arguments:
<yaml-config> Path to the yaml hyper-parameter file
<checkpoint> Path to the checkpoint
<image> Path to the directory containing processed images
<images> Path to images
Options:
-h --help Show this screen.
Expand Down Expand Up @@ -83,55 +83,62 @@ def main():
model = model.to(device)
model.eval()

im = skimage.io.imread(args["<image>"])[:, :, :3]
im_resized = skimage.transform.resize(im, (512, 512)) * 255
image = (im_resized - M.image.mean) / M.image.stddev
image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
with torch.no_grad():
input_dict = {
"image": image.to(device),
"meta": [
{
"junc": torch.zeros(1, 2).to(device),
"jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
"Lpos": torch.zeros(2, 2, dtype=torch.uint8).to(device),
"Lneg": torch.zeros(2, 2, dtype=torch.uint8).to(device),
}
],
"target": {
"jmap": torch.zeros([1, 1, 128, 128]).to(device),
"joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
},
"do_evaluation": True,
}
H = model(input_dict)["preds"]

lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
scores = H["score"][0].cpu().numpy()
for i in range(1, len(lines)):
if (lines[i] == lines[0]).all():
lines = lines[:i]
scores = scores[:i]
break

# postprocess lines to remove overlapped lines
diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

plt.gca().set_axis_off()
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
for i, t in enumerate([0.95, 0.96, 0.97, 0.98, 0.99]):
for (a, b), s in zip(nlines, nscores):
if s < t:
continue
plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
plt.scatter(a[1], a[0], **PLTOPTS)
plt.scatter(b[1], b[0], **PLTOPTS)
plt.imshow(im)
plt.show()
for imname in args["<images>"]:
print(f"Processing {imname}")
im = skimage.io.imread(imname)
if im.ndim == 2:
im = np.repeat(im[:, :, None], 3, 2)
im = im[:, :, :3]
im_resized = skimage.transform.resize(im, (512, 512)) * 255
image = (im_resized - M.image.mean) / M.image.stddev
image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
with torch.no_grad():
input_dict = {
"image": image.to(device),
"meta": [
{
"junc": torch.zeros(1, 2).to(device),
"jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
"Lpos": torch.zeros(2, 2, dtype=torch.uint8).to(device),
"Lneg": torch.zeros(2, 2, dtype=torch.uint8).to(device),
}
],
"target": {
"jmap": torch.zeros([1, 1, 128, 128]).to(device),
"joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
},
"do_evaluation": True,
}
H = model(input_dict)["preds"]

lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
scores = H["score"][0].cpu().numpy()
for i in range(1, len(lines)):
if (lines[i] == lines[0]).all():
lines = lines[:i]
scores = scores[:i]
break

# postprocess lines to remove overlapped lines
diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

for i, t in enumerate([0.94, 0.95, 0.96, 0.97, 0.98, 0.99]):
plt.gca().set_axis_off()
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
for (a, b), s in zip(nlines, nscores):
if s < t:
continue
plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
plt.scatter(a[1], a[0], **PLTOPTS)
plt.scatter(b[1], b[0], **PLTOPTS)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.imshow(im)
plt.savefig(imname.replace(".png", f"-{t:.02f}.svg"), bbox_inches="tight")
plt.show()
plt.close()


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions lcnn/models/line_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def sample_lines(self, meta, jmap, joff, do_evaluation):
K = min(int((jmap > M.eval_junc_thres).float().sum().item()), max_K)
else:
K = min(int(N * 2 + 2), max_K)
if K < 2:
K = 2
device = jmap.device

# index: [N_TYPE, K]
Expand Down

0 comments on commit ce9895e

Please sign in to comment.