Skip to content

Commit

Permalink
Add batch_size_eval. Closes #20.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhou13 committed Feb 11, 2020
1 parent 65ac416 commit c044515
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
1 change: 1 addition & 0 deletions config/wireframe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ model:
stddev: [22.275, 22.124, 23.229]

batch_size: 6
batch_size_eval: 2

# backbone multi-task parameters
head_size: [[2], [1], [2]]
Expand Down
8 changes: 4 additions & 4 deletions lcnn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from skimage import io
from tensorboardX import SummaryWriter

from lcnn.config import C
from lcnn.config import C, M
from lcnn.utils import recursive_to


Expand Down Expand Up @@ -103,8 +103,8 @@ def validate(self):
training = self.model.training
self.model.eval()

viz = osp.join(self.out, "viz", f"{self.iteration * self.batch_size:09d}")
npz = osp.join(self.out, "npz", f"{self.iteration * self.batch_size:09d}")
viz = osp.join(self.out, "viz", f"{self.iteration * M.batch_size_eval:09d}")
npz = osp.join(self.out, "npz", f"{self.iteration * M.batch_size_eval:09d}")
osp.exists(viz) or os.makedirs(viz)
osp.exists(npz) or os.makedirs(npz)

Expand All @@ -124,7 +124,7 @@ def validate(self):

H = result["preds"]
for i in range(H["jmap"].shape[0]):
index = batch_idx * self.batch_size + i
index = batch_idx * M.batch_size_eval + i
np.savez(
f"{npz}/{index:06}.npz",
**{k: v[i].cpu().numpy() for k, v in H.items()},
Expand Down
5 changes: 4 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ def main():
**kwargs,
)
val_loader = torch.utils.data.DataLoader(
WireframeDataset(datadir, split="valid"), shuffle=False, batch_size=2, **kwargs
WireframeDataset(datadir, split="valid"),
shuffle=False,
batch_size=M.batch_size_eval,
**kwargs,
)
epoch_size = len(train_loader)
# print("epoch_size (train):", epoch_size)
Expand Down

0 comments on commit c044515

Please sign in to comment.