-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathteleop_datamodule.py
31 lines (25 loc) · 1.22 KB
/
teleop_datamodule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import pytorch_lightning as pl
from torch.utils.data import DataLoader
# from data.citywalk_dataset import CityWalkDataset
from data.teleop_dataset import TeleopDataset
class TeleopDataModule(pl.LightningDataModule):
    """Lightning data module serving the train/val/test splits of ``TeleopDataset``.

    Args:
        cfg: Experiment config. This module reads ``cfg.training.batch_size``
            and ``cfg.data.num_workers``; the whole ``cfg`` object is also
            passed through unchanged to every ``TeleopDataset`` it builds.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.batch_size = cfg.training.batch_size
        self.num_workers = cfg.data.num_workers

    def setup(self, stage=None):
        """Build the dataset split(s) required for ``stage``.

        Args:
            stage: Lightning stage name (``'fit'``, ``'validate'``, ``'test'``,
                ``'predict'``) or ``None`` to build every split. ``'predict'``
                builds nothing, since this module defines no predict loader.
        """
        if stage in ('fit', None):
            self.train_dataset = TeleopDataset(self.cfg, mode='train')
        # `trainer.validate()` calls setup('validate'); without this stage in
        # the check, val_dataloader() would fail on a missing `val_dataset`.
        if stage in ('fit', 'validate', None):
            self.val_dataset = TeleopDataset(self.cfg, mode='val')
        if stage in ('test', None):
            self.test_dataset = TeleopDataset(self.cfg, mode='test')

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the configured batch size and workers."""
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=shuffle)

    def train_dataloader(self):
        """Return the shuffled training loader (requires ``setup`` to have built it)."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader in dataset order (no shuffling)."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test loader in dataset order (no shuffling)."""
        return self._make_loader(self.test_dataset, shuffle=False)