Skip to content

Commit

Permalink
More comments on the data format. Closes #2
Browse files Browse the repository at this point in the history
  • Loading branch information
zhou13 committed May 16, 2019
1 parent 36f172d commit 7aee92c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
20 changes: 11 additions & 9 deletions dataset/wireframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def jid(jun):
vint0, vint1 = to_int(v0[:2] / 2), to_int(v1[:2] / 2)
rr, cc, value = skimage.draw.line_aa(*vint0, *vint1)
lneg.append([v0, v1, i0, i1, np.average(np.minimum(value, llmap[rr, cc]))])
# assert np.sum((v0 - v1) ** 2) > 0.01

assert len(lneg) != 0
lneg.sort(key=lambda l: -l[-1])
Expand All @@ -115,17 +114,20 @@ def jid(jun):
# plt.plot([junc[i0][1], junc[i1][1]], [junc[i0][0], junc[i1][0]])
# plt.show()

# For junc, lpos, and lneg that stores the junction coordinates, the last
# dimension is (y, x, t), where t represents the type of that junction. In
# the wireframe dataset, t is always zero.
np.savez_compressed(
f"{prefix}_label.npz",
aspect_ratio=image.shape[1] / image.shape[0],
jmap=jmap, # [J, H, W]
joff=joff, # [J, 2, H, W]
lmap=lmap, # [H, W]
junc=junc, # [Na, 3]
Lpos=Lpos, # [M, 2]
Lneg=Lneg, # [M, 2]
lpos=lpos, # [Np, 2, 3] (y, x, t) for the last dim
lneg=lneg, # [Nn, 2, 3]
jmap=jmap, # [J, H, W] Junction heat map
joff=joff, # [J, 2, H, W] Junction offset within each pixel
lmap=lmap, # [H, W] Line heat map with anti-aliasing
junc=junc, # [Na, 3] Junction coordinate
Lpos=Lpos, # [M, 2] Positive lines represented with junction indices
Lneg=Lneg, # [M, 2] Negative lines represented with junction indices
lpos=lpos, # [Np, 2, 3] Positive lines represented with junction coordinates
lneg=lneg, # [Nn, 2, 3] Negative lines represented with junction coordinates
)
cv2.imwrite(f"{prefix}.png", image)

Expand Down
21 changes: 14 additions & 7 deletions lcnn/datasets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os
import glob
import json
import math
import os
import random

import numpy as np
import torch
import numpy.linalg as LA
import torch
from skimage import io
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
Expand All @@ -15,11 +15,7 @@


class WireframeDataset(Dataset):
def __init__(
self,
rootdir,
split,
):
def __init__(self, rootdir, split):
self.rootdir = rootdir
filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
filelist.sort()
Expand All @@ -39,6 +35,17 @@ def __getitem__(self, idx):
image = (image - M.image.mean) / M.image.stddev
image = np.rollaxis(image, 2).copy()

# npz["jmap"]: [J, H, W] Junction heat map
# npz["joff"]: [J, 2, H, W] Junction offset within each pixel
# npz["lmap"]: [H, W] Line heat map with anti-aliasing
# npz["junc"]: [Na, 3] Junction coordinates
# npz["Lpos"]: [M, 2] Positive lines represented with junction indices
# npz["Lneg"]: [M, 2] Negative lines represented with junction indices
# npz["lpos"]: [Np, 2, 3] Positive lines represented with junction coordinates
# npz["lneg"]: [Nn, 2, 3] Negative lines represented with junction coordinates
#
# For junc, lpos, and lneg that stores the junction coordinates, the last
# dimension is (y, x, t), where t represents the type of that junction.
with np.load(self.filelist[idx]) as npz:
target = {
name: torch.from_numpy(npz[name]).float()
Expand Down

0 comments on commit 7aee92c

Please sign in to comment.