Skip to content

Commit

Permalink
add finetuning for jepa
Browse files (browse the repository at this point in the history)
  • Loading branch information
Gaaaavin committed Nov 9, 2024
1 parent fd327df commit 44ec51c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 19 deletions.
18 changes: 8 additions & 10 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from pl_modules.citywalk_datamodule import CityWalkDataModule
from pl_modules.urbannav_datamodule import UrbanNavDataModule
from pl_modules.urban_nav_module import UrbanNavModule
from pytorch_lightning.strategies import DDPStrategy
from pl_modules.urbannav_jepa_module import UrbanNavJEPAModule
import torch
import glob

torch.set_float32_matmul_precision('medium')
pl.seed_everything(42, workers=True)
Expand Down Expand Up @@ -59,14 +58,13 @@ def main():
else:
raise ValueError(f"Invalid dataset: {cfg.data.type}")

# Initialize the model from checkpoint
model = UrbanNavModule.load_from_checkpoint(args.checkpoint, cfg=cfg)
# for param in model.model.compress_goal_enc.parameters():
# param.requires_grad = False
# for param in model.model.decoder.parameters():
# param.requires_grad = False
# for param in model.model.decoder.output_layers.parameters():
# param.requires_grad = True
# Initialize the model
if cfg.model.type == 'urbannav':
model = UrbanNavModule(cfg)
elif cfg.model.type == 'urbannav_jepa':
model = UrbanNavJEPAModule(cfg)
else:
raise ValueError(f"Invalid model: {cfg.model.type}")
print(f"Loaded model from checkpoint: {args.checkpoint}")
print(pl.utilities.model_summary.ModelSummary(model, max_depth=2))

Expand Down
16 changes: 11 additions & 5 deletions model/urban_nav_jepa.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,24 @@ def forward(self, obs, cord, future_obs=None):
"""
B, N, _, H, W = obs.shape
obs = obs.view(B * N, 3, H, W)
future_obs = future_obs.view(B * N, 3, H, W)
if future_obs:
future_obs = future_obs.view(B * N, 3, H, W)
if self.do_rgb_normalize:
obs = (obs - self.mean) / self.std
future_obs = (future_obs - self.mean) / self.std
if future_obs:
future_obs = (future_obs - self.mean) / self.std
if self.do_resize:
obs = TF.center_crop(obs, self. crop)
obs = TF.resize(obs, self.resize)
future_obs = TF.center_crop(future_obs, self.crop)
future_obs = TF.resize(future_obs, self.resize)
if future_obs:
future_obs = TF.center_crop(future_obs, self.crop)
future_obs = TF.resize(future_obs, self.resize)

obs_enc = self.obs_encoder(obs).view(B, N, -1)
future_obs_enc = self.obs_encoder(future_obs).view(B, N, -1)
if future_obs:
future_obs_enc = self.obs_encoder(future_obs).view(B, N, -1)
else:
future_obs_enc = None

# Coordinate Encoding
cord_enc = self.cord_embedding(cord).view(B, -1)
Expand Down
20 changes: 16 additions & 4 deletions pl_modules/urbannav_jepa_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ def forward(self, obs, cord, future_obs):
def training_step(self, batch, batch_idx):
obs = batch['video_frames']
cord = batch['input_positions']
future_obs = batch['future_video_frames']
if "future_video_frames" in batch:
future_obs = batch['future_video_frames']
else:
future_obs = None

wp_pred, arrive_pred, feature_pred, feature_gt = self(obs, cord, future_obs)
losses = self.compute_loss(wp_pred, arrive_pred, feature_pred, feature_gt, batch)
Expand All @@ -71,7 +74,10 @@ def training_step(self, batch, batch_idx):
def validation_step(self, batch, batch_idx):
obs = batch['video_frames']
cord = batch['input_positions']
future_obs = batch['future_video_frames']
if "future_video_frames" in batch:
future_obs = batch['future_video_frames']
else:
future_obs = None
wp_pred, arrive_pred, feature_pred, feature_gt = self(obs, cord, future_obs)
losses = self.compute_loss(wp_pred, arrive_pred, feature_pred, feature_gt, batch)
l1_loss = losses['waypoints_loss']
Expand Down Expand Up @@ -107,7 +113,10 @@ def validation_step(self, batch, batch_idx):
def test_step(self, batch, batch_idx):
obs = batch['video_frames']
cord = batch['input_positions']
future_obs = batch['future_video_frames']
if "future_video_frames" in batch:
future_obs = batch['future_video_frames']
else:
future_obs = None
B, T, _, _, _ = obs.shape

if self.datatype == "citywalk":
Expand Down Expand Up @@ -283,7 +292,10 @@ def on_test_epoch_start(self):
def compute_loss(self, wp_pred, arrive_pred, feature_pred, feature_gt, batch):
waypoints_target = batch['waypoints']
arrived_target = batch['arrived']
feature_loss = F.mse_loss(feature_pred, feature_gt)
if feature_pred is not None and feature_gt is not None:
feature_loss = F.mse_loss(feature_pred, feature_gt)
else:
feature_loss = 0.0
wp_loss = F.l1_loss(wp_pred, waypoints_target)
arrived_loss = F.binary_cross_entropy_with_logits(arrive_pred.flatten(), arrived_target)

Expand Down

0 comments on commit 44ec51c

Please sign in to comment.