Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaaaavin committed Nov 22, 2024
1 parent ccfe1db commit 8a6397d
Show file tree
Hide file tree
Showing 13 changed files with 57 additions and 70 deletions.
20 changes: 10 additions & 10 deletions config/citywalk_2000hr.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
project:
name: 'UrbanNav_Citywalk'
run_name: 'dinov2_2000_hr_jepa'
name: 'CityWalker'
run_name: 'train_2000hr'
result_dir: 'results'

training:
batch_size: 32
max_epochs: 8
max_epochs: 10
gpus: 1
amp: false
normalize_step_length: true
Expand All @@ -15,15 +15,15 @@ training:

scheduler:
name: 'cosine'
step_size: 8
step_size: 10
gamma: 0.1

optimizer:
name: 'adamw'
lr: 2e-4

model:
type: 'urbannav_jepa'
type: 'citywalker_feat'
do_rgb_normalize: true
do_resize: true
obs_encoder:
Expand All @@ -49,9 +49,9 @@ model:
ff_dim_factor: 4

data:
type: citywalk_jepa
video_dir: '/vast/xl3136/citywalk_2min/videos'
pose_dir: '/vast/xl3136/citywalk_2min/poses'
type: citywalk_feat
video_dir: 'dataset/citywalk_2min/videos'
pose_dir: 'dataset/citywalk_2min/poses'
num_workers: 23
pose_fps: 5
video_fps: 30
Expand All @@ -65,9 +65,9 @@ data:
arrived_prob: 0.3

validation:
num_visualize: 0
num_visualize: 400
testing:
num_visualize: 0
num_visualize: 400

logging:
enable_wandb: true # Set to false to disable Wandb logging
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __iter__(self):
def __len__(self):
return len(self.indices)

class CityWalkJEPADataset(Dataset):
class CityWalkFeatDataset(Dataset):
def __init__(self, cfg, mode):
super().__init__()
self.cfg = cfg
Expand Down
15 changes: 1 addition & 14 deletions data/urbannav_dataset.py → data/teleop_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import random
from PIL import Image

class UrbanNavDataset(Dataset):
class TeleopDataset(Dataset):
def __init__(self, cfg, mode):
super().__init__()
self.cfg = cfg
Expand All @@ -23,14 +23,7 @@ def __init__(self, cfg, mode):
self.arrived_threshold = cfg.data.arrived_threshold
self.arrived_prob = cfg.data.arrived_prob

# Load pose paths
# self.pose_path = [
# os.path.join(self.pose_dir, f)
# for f in sorted(os.listdir(self.pose_dir))
# if f.startswith('match_gps_pose') and f.endswith('.txt')
# ]
pose_files = [
# "pose_label_1.txt",
"pose_label_6.txt",
"pose_label_7.txt",
"pose_label_8.txt",
Expand Down Expand Up @@ -211,8 +204,6 @@ def __getitem__(self, index):
input_positions = input_positions @ rot_matrix.T
elif self.cfg.model.cord_embedding.type == 'input_target':
input_positions = self.transform_input(input_gps_positions)
# print(input_positions.shape)
# print(target_transformed.shape)
input_positions = np.concatenate([input_positions, target_transformed[np.newaxis, :2]], axis=0)
else:
raise NotImplementedError(f"Coordinate embedding type {self.cfg.model.cord_embedding.type} not implemented")
Expand All @@ -238,10 +229,6 @@ def __getitem__(self, index):
'arrived': arrived,
'step_scale': step_scale
}
# print("input", input_positions)
# print("history", history_positions)
# print("wp", waypoints_transformed)
# print("target", target_transformed)

if self.mode in ['val', 'test']:
# For visualization
Expand Down
12 changes: 6 additions & 6 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import yaml
import os
from pl_modules.citywalk_datamodule import CityWalkDataModule
from pl_modules.urbannav_datamodule import UrbanNavDataModule
from pl_modules.urban_nav_module import UrbanNavModule
from pl_modules.urbannav_jepa_module import UrbanNavJEPAModule
from pl_modules.teleop_datamodule import TeleopDataModule
from pl_modules.citywalker_module import CityWalkerModule
from pl_modules.citywalker_feat_module import CityWalkerFeatModule
import torch

torch.set_float32_matmul_precision('medium')
Expand Down Expand Up @@ -54,15 +54,15 @@ def main():
if cfg.data.type == 'citywalk':
datamodule = CityWalkDataModule(cfg)
elif cfg.data.type == 'urbannav':
datamodule = UrbanNavDataModule(cfg)
datamodule = TeleopDataModule(cfg)
else:
raise ValueError(f"Invalid dataset: {cfg.data.type}")

# Initialize the model
if cfg.model.type == 'urbannav':
model = UrbanNavModule.load_from_checkpoint(args.checkpoint, cfg=cfg)
model = CityWalkerModule.load_from_checkpoint(args.checkpoint, cfg=cfg)
elif cfg.model.type == 'urbannav_jepa':
model = UrbanNavJEPAModule.load_from_checkpoint(args.checkpoint, cfg=cfg)
model = CityWalkerFeatModule.load_from_checkpoint(args.checkpoint, cfg=cfg)
else:
raise ValueError(f"Invalid model: {cfg.model.type}")
print(f"Loaded model from checkpoint: {args.checkpoint}")
Expand Down
2 changes: 1 addition & 1 deletion model/urban_nav.py → model/citywalker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from model.model_utils import PolarEmbedding, MultiLayerDecoder, PositionalEncoding
from torchvision import models

class UrbanNav(nn.Module):
class CityWalker(nn.Module):
def __init__(self, cfg):
super().__init__()
self.context_size = cfg.model.obs_encoder.context_size
Expand Down
6 changes: 3 additions & 3 deletions model/urban_nav_jepa.py → model/citywalker_feat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
from model.model_utils import PolarEmbedding, JEPAPredictor, PositionalEncoding
from model.model_utils import PolarEmbedding, FeatPredictor, PositionalEncoding
from torchvision import models

class UrbanNavJEPA(nn.Module):
class CityWalkerFeat(nn.Module):
def __init__(self, cfg):
super().__init__()
self.context_size = cfg.model.obs_encoder.context_size
Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(self, cfg):
raise NotImplementedError(f"Coordinate embedding type {self.cord_embedding_type} not implemented")

# Decoder
self.predictor = JEPAPredictor(
self.predictor = FeatPredictor(
embed_dim=self.num_obs_features,
seq_len=self.context_size+1,
nhead=cfg.model.decoder.num_heads,
Expand Down
2 changes: 1 addition & 1 deletion model/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def forward(self, x):
x = F.relu(x)
return x

class JEPAPredictor(nn.Module):
class FeatPredictor(nn.Module):
def __init__(self, embed_dim=512, seq_len=6, nhead=8, num_layers=8, ff_dim_factor=4):
super().__init__()
self.positional_encoding = PositionalEncoding(embed_dim, max_seq_len=seq_len)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from data.citywalk_dataset import CityWalkSampler
from data.citywalk_jepa_dataset import CityWalkJEPADataset
from data.citywalk_feat_dataset import CityWalkFeatDataset

class CityWalkJEPADataModule(pl.LightningDataModule):
class CityWalkFeatDataModule(pl.LightningDataModule):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
Expand All @@ -14,11 +14,11 @@ def __init__(self, cfg):

def setup(self, stage=None):
if stage == 'fit' or stage is None:
self.train_dataset = CityWalkJEPADataset(self.cfg, mode='train')
self.val_dataset = CityWalkJEPADataset(self.cfg, mode='val')
self.train_dataset = CityWalkFeatDataset(self.cfg, mode='train')
self.val_dataset = CityWalkFeatDataset(self.cfg, mode='val')

if stage == 'test' or stage is None:
self.test_dataset = CityWalkJEPADataset(self.cfg, mode='test')
self.test_dataset = CityWalkFeatDataset(self.cfg, mode='test')

def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
import numpy as np
import torch
import torch.nn.functional as F
from model.urban_nav_jepa import UrbanNavJEPA
from model.citywalker_feat import CityWalkerFeat
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
plt.style.use('seaborn-v0_8')
import os

class UrbanNavJEPAModule(pl.LightningModule):
class CityWalkerFeatModule(pl.LightningModule):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.model = UrbanNavJEPA(cfg)
self.model = CityWalkerFeat(cfg)
self.save_hyperparameters(cfg)
self.do_normalize = cfg.training.normalize_step_length
self.datatype = cfg.data.type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
import numpy as np
import torch
import torch.nn.functional as F
from model.urban_nav import UrbanNav
from model.citywalker import CityWalker
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
import os

class UrbanNavModule(pl.LightningModule):
class CityWalkerModule(pl.LightningModule):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.model = UrbanNav(cfg)
self.model = CityWalker(cfg)
self.save_hyperparameters(cfg)
self.do_normalize = cfg.training.normalize_step_length
self.datatype = cfg.data.type
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytorch_lightning as pl
from torch.utils.data import DataLoader
# from data.citywalk_dataset import CityWalkDataset
from data.urbannav_dataset import UrbanNavDataset
from data.teleop_dataset import TeleopDataset

class UrbanNavDataModule(pl.LightningDataModule):
class TeleopDataModule(pl.LightningDataModule):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
Expand All @@ -12,11 +12,11 @@ def __init__(self, cfg):

def setup(self, stage=None):
if stage == 'fit' or stage is None:
self.train_dataset = UrbanNavDataset(self.cfg, mode='train')
self.val_dataset = UrbanNavDataset(self.cfg, mode='val')
self.train_dataset = TeleopDataset(self.cfg, mode='train')
self.val_dataset = TeleopDataset(self.cfg, mode='val')

if stage == 'test' or stage is None:
self.test_dataset = UrbanNavDataset(self.cfg, mode='test')
self.test_dataset = TeleopDataset(self.cfg, mode='test')

def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size,
Expand Down
12 changes: 6 additions & 6 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import yaml
import os
from pl_modules.citywalk_datamodule import CityWalkDataModule
from pl_modules.urbannav_datamodule import UrbanNavDataModule
from pl_modules.urban_nav_module import UrbanNavModule
from pl_modules.urbannav_jepa_module import UrbanNavJEPAModule
from pl_modules.teleop_datamodule import TeleopDataModule
from pl_modules.citywalker_module import CityWalkerModule
from pl_modules.citywalker_feat_module import CityWalkerFeatModule
import torch
import glob

Expand Down Expand Up @@ -76,7 +76,7 @@ def main():
if cfg.data.type == 'citywalk':
datamodule = CityWalkDataModule(cfg)
elif cfg.data.type == 'urbannav':
datamodule = UrbanNavDataModule(cfg)
datamodule = TeleopDataModule(cfg)
else:
raise ValueError(f"Invalid dataset: {cfg.data.dataset}")

Expand All @@ -95,9 +95,9 @@ def main():

# Load the model from the checkpoint
if cfg.model.type == 'urbannav':
model = UrbanNavModule.load_from_checkpoint(checkpoint_path, cfg=cfg)
model = CityWalkerModule.load_from_checkpoint(checkpoint_path, cfg=cfg)
elif cfg.model.type == 'urbannav_jepa':
model = UrbanNavJEPAModule.load_from_checkpoint(checkpoint_path, cfg=cfg)
model = CityWalkerFeatModule.load_from_checkpoint(checkpoint_path, cfg=cfg)
else:
raise ValueError(f"Invalid model: {cfg.model.type}")
model.result_dir = test_dir
Expand Down
24 changes: 12 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import yaml
import os
from pl_modules.citywalk_datamodule import CityWalkDataModule
from pl_modules.urbannav_datamodule import UrbanNavDataModule
from pl_modules.urban_nav_module import UrbanNavModule
from pl_modules.urbannav_jepa_module import UrbanNavJEPAModule
from pl_modules.citywalk_jepa_datamodule import CityWalkJEPADataModule
from pl_modules.teleop_datamodule import TeleopDataModule
from pl_modules.citywalker_module import CityWalkerModule
from pl_modules.citywalker_feat_module import CityWalkerFeatModule
from pl_modules.citywalk_feat_datamodule import CityWalkFeatDataModule
from pytorch_lightning.strategies import DDPStrategy
import torch
import glob
Expand Down Expand Up @@ -80,18 +80,18 @@ def main():
# Initialize the DataModule
if cfg.data.type == 'citywalk':
datamodule = CityWalkDataModule(cfg)
elif cfg.data.type == 'urbannav':
datamodule = UrbanNavDataModule(cfg)
elif cfg.data.type == 'citywalk_jepa':
datamodule = CityWalkJEPADataModule(cfg)
elif cfg.data.type == 'teleop':
datamodule = TeleopDataModule(cfg)
elif cfg.data.type == 'citywalk_feat':
datamodule = CityWalkFeatDataModule(cfg)
else:
raise ValueError(f"Invalid dataset: {cfg.data.dataset}")

# Initialize the model
if cfg.model.type == 'urbannav':
model = UrbanNavModule(cfg)
elif cfg.model.type == 'urbannav_jepa':
model = UrbanNavJEPAModule(cfg)
if cfg.model.type == 'citywalker':
model = CityWalkerModule(cfg)
elif cfg.model.type == 'citywalker_feat':
model = CityWalkerFeatModule(cfg)
else:
raise ValueError(f"Invalid model: {cfg.model.type}")
print(pl.utilities.model_summary.ModelSummary(model, max_depth=2))
Expand Down

0 comments on commit 8a6397d

Please sign in to comment.