diff --git a/train.py b/train.py index 44d4731..f856745 100755 --- a/train.py +++ b/train.py @@ -93,16 +93,18 @@ def main(): datadir = C.io.datadir kwargs = { - "batch_size": M.batch_size, "collate_fn": collate, "num_workers": C.io.num_workers, "pin_memory": True, } train_loader = torch.utils.data.DataLoader( - WireframeDataset(datadir, split="train"), shuffle=True, **kwargs + WireframeDataset(datadir, split="train"), + shuffle=True, + batch_size=M.batch_size, + **kwargs, ) val_loader = torch.utils.data.DataLoader( - WireframeDataset(datadir, split="valid"), shuffle=False, **kwargs + WireframeDataset(datadir, split="valid"), shuffle=False, batch_size=2, **kwargs ) epoch_size = len(train_loader) # print("epoch_size (train):", epoch_size)