diff --git a/config/pretrain/toy.yaml b/config/pretrain/toy.yaml new file mode 100644 index 0000000..e7182dd --- /dev/null +++ b/config/pretrain/toy.yaml @@ -0,0 +1,48 @@ +project: + name: 'UrbanNav_Citywalk' + run_name: 'toy_data_pretrain' + result_dir: 'results' + +training: + batch_size: 8 + max_epochs: 10 + gpus: 1 + amp: true + normalize_step_length: true + resume: true + direction_loss_weight: 1.0 + reg_coeff: 0.1 + +scheduler: + name: 'cosine' + step_size: 10 + gamma: 0.1 + +optimizer: + name: 'adamw' + lr: 1e-4 + +model: + do_rgb_normalize: true + do_resize: true + obs_encoder: + type: 'vit_l_16' + context_size: 5 + target_size: 5 + encoder_feat_dim: 1024 + ema_range: [0.998, 1] + +data: + type: citywalk + video_dir: '/home/xinhao/citywalk/test_videos' + num_workers: 6 + pose_fps: 5 + video_fps: 30 + target_fps: 1 + num_train: 40 + num_val: 20 + num_test: 20 + +logging: + enable_wandb: false # Set to false to disable Wandb logging + pbar_rate: 1 \ No newline at end of file diff --git a/data/citywalk_pretrain_dataset.py b/data/citywalk_pretrain_dataset.py new file mode 100644 index 0000000..f8f3259 --- /dev/null +++ b/data/citywalk_pretrain_dataset.py @@ -0,0 +1,129 @@ +import os +import numpy as np +import torch +from torch.utils.data import Dataset +from decord import VideoReader, cpu +import torch.nn.functional as F +from tqdm import tqdm + +class VideoJEPADataset(Dataset): + def __init__(self, cfg, mode): + super().__init__() + self.cfg = cfg + self.mode = mode + self.video_dir = cfg.data.video_dir + self.context_size = cfg.model.obs_encoder.context_size # Number of input frames + self.target_size = cfg.model.obs_encoder.target_size # Number of target frames + self.video_fps = cfg.data.video_fps + self.target_fps = cfg.data.target_fps + self.frame_multiplier = self.video_fps // self.target_fps + + # Load video paths + self.video_paths = [ + os.path.join(self.video_dir, f) + for f in sorted(os.listdir(self.video_dir)) + if f.endswith('.mp4') + ] + print(f"Total videos found: {len(self.video_paths)}") + + # Split videos according to mode + if mode == 'train': + self.video_paths = self.video_paths[:cfg.data.num_train] + elif mode == 'val': + self.video_paths = self.video_paths[cfg.data.num_train: cfg.data.num_train + cfg.data.num_val] + elif mode == 'test': + self.video_paths = self.video_paths[cfg.data.num_train + cfg.data.num_val: + cfg.data.num_train + cfg.data.num_val + cfg.data.num_test] + else: + raise ValueError(f"Invalid mode {mode}") + + print(f"Number of videos for {mode}: {len(self.video_paths)}") + + # Build the look-up table (lut) and video_ranges + self.lut = [] + self.video_ranges = [] + idx_counter = 0 + for video_idx, video_path in enumerate(tqdm(self.video_paths, desc="Building LUT")): + # Initialize VideoReader to get number of frames + vr = VideoReader(video_path, ctx=cpu(0)) + num_frames = len(vr) + usable_frames = num_frames // self.frame_multiplier - (self.context_size + self.target_size) + if usable_frames <= 0: + continue # Skip videos that are too short + start_idx = idx_counter + for frame_start in range(0, usable_frames, self.context_size): + self.lut.append((video_idx, frame_start)) + idx_counter += 1 + end_idx = idx_counter + self.video_ranges.append((start_idx, end_idx)) + assert len(self.lut) > 0, "No usable samples found." 
+ + print(f"Total samples in LUT: {len(self.lut)}") + print(f"Total video ranges: {len(self.video_ranges)}") + + # Initialize the video reader cache per worker + self.video_reader_cache = {'video_idx': None, 'video_reader': None} + + def __len__(self): + return len(self.lut) + + def __getitem__(self, index): + video_idx, frame_start = self.lut[index] + + # Retrieve or create the VideoReader for the current video + if self.video_reader_cache['video_idx'] != video_idx: + # Replace the old VideoReader with the new one + self.video_reader_cache['video_reader'] = VideoReader(self.video_paths[video_idx], ctx=cpu(0)) + self.video_reader_cache['video_idx'] = video_idx + video_reader = self.video_reader_cache['video_reader'] + + # Compute actual frame indices for input and target + actual_frame_start = frame_start * self.frame_multiplier + frame_indices_input = actual_frame_start + np.arange(self.context_size) * self.frame_multiplier + frame_indices_target = actual_frame_start + (self.context_size + np.arange(self.target_size)) * self.frame_multiplier + + # Ensure frame indices are within the video length + num_frames = len(video_reader) + frame_indices_input = [min(idx, num_frames - 1) for idx in frame_indices_input] + frame_indices_target = [min(idx, num_frames - 1) for idx in frame_indices_target] + + # Load the frames + input_frames = video_reader.get_batch(frame_indices_input).asnumpy() + target_frames = video_reader.get_batch(frame_indices_target).asnumpy() + + # Process frames + input_frames = self.process_frames(input_frames) + target_frames = self.process_frames(target_frames) + + sample = { + 'input_frames': input_frames, # Shape: (context_size, 3, H, W) + 'target_frames': target_frames # Shape: (target_size, 3, H, W) + } + + return sample + + def process_frames(self, frames): + """ + Convert frames to tensor, normalize, and resize if necessary. + + Args: + frames (numpy.ndarray): Array of frames with shape (N, H, W, C). + + Returns: + torch.Tensor: Processed frames with shape (N, 3, H, W). 
+ """ + # Convert frames to tensor and normalize + frames = torch.tensor(frames).permute(0, 3, 1, 2).float() / 255.0 # Shape: (N, 3, H, W) + + # # Optional resizing + # if self.cfg.data.do_resize: + # frames = F.interpolate(frames, size=(self.cfg.data.resize_height, self.cfg.data.resize_width), + # mode='bilinear', align_corners=False) + + # # Optional normalization + # if self.cfg.data.do_rgb_normalize: + # mean = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1).to(frames.device) + # std = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1).to(frames.device) + # frames = (frames - mean) / std + + return frames diff --git a/model/vjepa.py b/model/vjepa.py new file mode 100644 index 0000000..51b6394 --- /dev/null +++ b/model/vjepa.py @@ -0,0 +1,124 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from torchvision import models +from efficientnet_pytorch import EfficientNet + +class ImageEncoder(nn.Module): + def __init__(self, cfg): + super().__init__() + self.obs_encoder_type = cfg.model.obs_encoder.type + self.do_rgb_normalize = cfg.model.do_rgb_normalize + self.do_resize = cfg.model.do_resize + self.encoder_feat_dim = cfg.model.encoder_feat_dim + + if self.do_rgb_normalize: + self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + # Build the observation encoder + if self.obs_encoder_type.startswith("efficientnet"): + self.obs_encoder = EfficientNet.from_name(self.obs_encoder_type, in_channels=3) + self.num_obs_features = self.obs_encoder._fc.in_features + self.obs_encoder._fc = nn.Identity() + elif self.obs_encoder_type.startswith("resnet"): + model_constructor = getattr(models, self.obs_encoder_type) + self.obs_encoder = model_constructor(pretrained=False) + self.num_obs_features = self.obs_encoder.fc.in_features + self.obs_encoder.fc = nn.Identity() + elif self.obs_encoder_type.startswith("vit"): + model_constructor = getattr(models, self.obs_encoder_type) + self.obs_encoder = model_constructor(pretrained=False) + self.num_obs_features = self.obs_encoder.hidden_dim + self.obs_encoder.heads = nn.Identity() + else: + raise NotImplementedError(f"Observation encoder type {self.obs_encoder_type} not implemented") + + # Compress observation encodings to encoder_feat_dim + if self.num_obs_features != self.encoder_feat_dim: + self.compress_obs_enc = nn.Linear(self.num_obs_features, self.encoder_feat_dim) + else: + self.compress_obs_enc = nn.Identity() + + def forward(self, obs): + """ + Args: + obs: (B*N, 3, H, W) + Returns: + embeddings: (B*N, encoder_feat_dim) + """ + # Pre-processing + if self.do_rgb_normalize: + obs = (obs - self.mean) / self.std + if self.do_resize: + obs = F.interpolate(obs, size=(224, 224), mode='bilinear', align_corners=False) + + # Observation Encoding + if self.obs_encoder_type.startswith("efficientnet"): + obs_enc = self.obs_encoder.extract_features(obs) + obs_enc = self.obs_encoder._avg_pooling(obs_enc) + obs_enc = obs_enc.flatten(start_dim=1) + elif self.obs_encoder_type.startswith("resnet"): + x = self.obs_encoder.conv1(obs) + x = self.obs_encoder.bn1(x) + x = self.obs_encoder.relu(x) + x = self.obs_encoder.maxpool(x) + x = self.obs_encoder.layer1(x) + x = self.obs_encoder.layer2(x) + x = self.obs_encoder.layer3(x) + x = self.obs_encoder.layer4(x) + x = self.obs_encoder.avgpool(x) + obs_enc = torch.flatten(x, 1) + elif self.obs_encoder_type.startswith("vit"): + obs_enc = self.obs_encoder(obs) # Returns class 
token embedding + else: + raise NotImplementedError(f"Observation encoder type {self.obs_encoder_type} not implemented") + + # Compress embeddings + obs_enc = self.compress_obs_enc(obs_enc) + + return obs_enc + +class VideoJEPA(nn.Module): + def __init__(self, cfg): + super().__init__() + + # Initialize the online encoder + self.online_encoder = ImageEncoder(cfg) + + # Initialize the target encoder as a copy of the online encoder + self.target_encoder = copy.deepcopy(self.online_encoder) + # Freeze the target encoder parameters (they will be updated via EMA) + for param in self.target_encoder.parameters(): + param.requires_grad = False + + @torch.no_grad() + def update_target_encoder(self, decay=0.998): + # EMA update for the target encoder + for param_o, param_t in zip(self.online_encoder.parameters(), self.target_encoder.parameters()): + param_t.data.mul_(decay).add_(param_o.data * (1 - decay)) + + def forward(self, input_obs, target_obs): + """ + Args: + input_obs: (B, N_in, 3, H, W) tensor of past frames + target_obs: (B, N_target, 3, H, W) tensor of future frames + Returns: + online_embeddings: (B, N_in, encoder_feat_dim) + target_embeddings: (B, N_target, encoder_feat_dim) + """ + # Process input observations through online encoder + B, N_in, C, H, W = input_obs.shape + input_obs_flat = input_obs.view(B * N_in, C, H, W) + online_embeddings_flat = self.online_encoder(input_obs_flat) + online_embeddings = online_embeddings_flat.view(B, N_in, -1) + + # Process target observations through target encoder with stop gradient + with torch.no_grad(): + B, N_target, C, H, W = target_obs.shape + target_obs_flat = target_obs.view(B * N_target, C, H, W) + target_embeddings_flat = self.target_encoder(target_obs_flat) + target_embeddings = target_embeddings_flat.view(B, N_target, -1) + + return online_embeddings, target_embeddings diff --git a/pl_modules/citywalk_pretrain_datamodule.py b/pl_modules/citywalk_pretrain_datamodule.py new file mode 100644 index 0000000..9c2e780 --- /dev/null +++ b/pl_modules/citywalk_pretrain_datamodule.py @@ -0,0 +1,33 @@ +# data/datamodule.py + +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from data.citywalk_dataset import CityWalkSampler +from data.citywalk_pretrain_dataset import VideoJEPADataset + +class CityWalkPretrainDataModule(pl.LightningDataModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.batch_size = cfg.training.batch_size + self.num_workers = cfg.data.num_workers + + def setup(self, stage=None): + if stage == 'fit' or stage is None: + self.train_dataset = VideoJEPADataset(self.cfg, mode='train') + self.val_dataset = VideoJEPADataset(self.cfg, mode='val') + + if stage == 'test' or stage is None: + self.test_dataset = VideoJEPADataset(self.cfg, mode='test') + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.batch_size, + num_workers=self.num_workers, sampler=CityWalkSampler(self.train_dataset)) + + def val_dataloader(self): + return DataLoader(self.val_dataset, batch_size=self.batch_size, + num_workers=self.num_workers, shuffle=False) + + def test_dataloader(self): + return DataLoader(self.test_dataset, batch_size=self.batch_size, + num_workers=self.num_workers, shuffle=False) diff --git a/pl_modules/vjepa_module.py b/pl_modules/vjepa_module.py new file mode 100644 index 0000000..76682b7 --- /dev/null +++ b/pl_modules/vjepa_module.py @@ -0,0 +1,142 @@ +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from torch.optim.lr_scheduler 
import StepLR, CosineAnnealingLR +from model.vjepa import VideoJEPA # Adjust the import path as needed + +class VJEPAModule(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.save_hyperparameters(cfg) + + # Initialize the Video-JEPA Encoder model + self.model = VideoJEPA(cfg) + + # EMA schedule parameters + self.ema_start = cfg.model.ema_range[0] # Starting EMA decay value (e.g., 0.99) + self.ema_end = cfg.model.ema_range[1] # Ending EMA decay value (e.g., 1.0) + self.total_epochs = cfg.training.max_epochs # Total number of training epochs + + # Regularization coefficient + self.reg_coeff = cfg.training.reg_coeff # e.g., 0.1 + + def forward(self, input_obs, target_obs): + return self.model(input_obs, target_obs) + + def training_step(self, batch, batch_idx): + input_frames = batch['input_frames'] # Shape: (B, N_in, 3, H, W) + target_frames = batch['target_frames'] # Shape: (B, N_target, 3, H, W) + + # Forward pass + online_embeddings, target_embeddings = self(input_frames, target_frames) + + # Compute JEPA loss (L1 loss) + loss_jepa = F.l1_loss(online_embeddings, target_embeddings) + + # Compute regularization loss + pstd_z = torch.sqrt(online_embeddings.var(dim=1) + 1e-4) # Shape: (B, feature_dim) + loss_reg = F.relu(1.0 - pstd_z).mean() + + # Total loss + loss = loss_jepa + self.reg_coeff * loss_reg + + # Log the losses + self.log('train/loss_jepa', loss_jepa, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True) + self.log('train/loss_reg', loss_reg, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True) + self.log('train/loss_total', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + # Update the target encoder with EMA + self.model.update_target_encoder(self.ema_momentum) + + return loss + + def validation_step(self, batch, batch_idx): + input_frames = batch['input_frames'] + target_frames = batch['target_frames'] + + # Forward pass + online_embeddings, target_embeddings = self(input_frames, target_frames) + + # Compute JEPA loss (L1 loss) + loss_jepa = F.l1_loss(online_embeddings, target_embeddings) + + # Compute regularization loss + pstd_z = torch.sqrt(online_embeddings.var(dim=1) + 1e-4) + loss_reg = F.relu(1.0 - pstd_z).mean() + + # Total loss + loss = loss_jepa + self.reg_coeff * loss_reg + + # Log the validation losses + self.log('val/loss_jepa', loss_jepa, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('val/loss_reg', loss_reg, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True) + self.log('val/loss_total', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + + def configure_optimizers(self): + optimizer_name = self.cfg.optimizer.name.lower() + lr = float(self.cfg.optimizer.lr) + + if optimizer_name == 'adam': + optimizer = torch.optim.Adam( + self.model.online_encoder.parameters(), + lr=lr, + weight_decay=self.cfg.optimizer.weight_decay + ) + elif optimizer_name == 'sgd': + optimizer = torch.optim.SGD( + self.model.online_encoder.parameters(), + lr=lr, + weight_decay=self.cfg.optimizer.weight_decay + ) + elif optimizer_name == 'adamw': + optimizer = torch.optim.AdamW( + self.model.online_encoder.parameters(), + lr=lr, + ) + else: + raise ValueError(f"Unsupported optimizer type: {self.cfg.optimizer.name}") + + # Configure scheduler + scheduler_cfg = self.cfg.scheduler + if scheduler_cfg.name.lower() == 'step_lr': + scheduler = StepLR( + optimizer, + step_size=scheduler_cfg.step_size, + gamma=scheduler_cfg.gamma + ) + return [optimizer], [scheduler] + 
elif scheduler_cfg.name.lower() == 'cosine': + scheduler = CosineAnnealingLR( + optimizer, + T_max=self.total_epochs + ) + return [optimizer], [scheduler] + elif scheduler_cfg.name.lower() == 'none': + return optimizer + else: + raise ValueError(f"Unsupported scheduler type: {scheduler_cfg.name}") + + def on_train_epoch_start(self): + # Compute EMA decay for the current epoch + m = self.compute_ema_decay(self.current_epoch) + self.ema_momentum = m + + def compute_ema_decay(self, epoch_num): + """ + Compute the EMA momentum based on the current epoch using a linear schedule. + + Args: + epoch_num (int): Current epoch number + + Returns: + float: Computed EMA momentum + """ + m_start = self.ema_start + m_end = self.ema_end + total_epochs = self.total_epochs + + # Linear interpolation of momentum + m = m_start + (m_end - m_start) * (epoch_num / total_epochs) + m = min(m, m_end) # Ensure m does not exceed m_end + return m diff --git a/pretrain.py b/pretrain.py new file mode 100644 index 0000000..47ce5a4 --- /dev/null +++ b/pretrain.py @@ -0,0 +1,173 @@ +# main.py + +import pytorch_lightning as pl +import argparse +import yaml +import os +from pl_modules.citywalk_pretrain_datamodule import CityWalkPretrainDataModule +from pl_modules.vjepa_module import VJEPAModule +from pytorch_lightning.strategies import DDPStrategy +import torch +import glob +torch.set_float32_matmul_precision('medium') + + +# Remove the WandbLogger import from the top +# from pytorch_lightning.loggers import WandbLogger + +class DictNamespace(argparse.Namespace): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + if isinstance(value, dict): + setattr(self, key, DictNamespace(**value)) + else: + setattr(self, key, value) + +def parse_args(): + parser = argparse.ArgumentParser(description='Train UrbanNav model') + parser.add_argument('--config', type=str, default='config/default.yaml', help='Path to config file') + parser.add_argument('--checkpoint', type=str, default=None, help='Path to model checkpoint. If not provided, the latest checkpoint will be used.') + args = parser.parse_args() + return args + +def load_config(config_path): + with open(config_path, 'r') as f: + cfg_dict = yaml.safe_load(f) + cfg = DictNamespace(**cfg_dict) + return cfg + +def find_latest_checkpoint(checkpoint_dir): + """ + Finds the latest checkpoint in the given directory based on modification time. + + Args: + checkpoint_dir (str): Path to the directory containing checkpoints. + + Returns: + str: Path to the latest checkpoint file. + + Raises: + FileNotFoundError: If no checkpoint files are found in the directory. 
+    """
+    checkpoint_pattern = os.path.join(checkpoint_dir, '*.ckpt')
+    checkpoint_files = glob.glob(checkpoint_pattern)
+    if not checkpoint_files:
+        raise FileNotFoundError(f"No checkpoint files found in directory: {checkpoint_dir}")
+
+    # Sort checkpoints by modification time (latest first)
+    checkpoint_files.sort(key=lambda x: os.path.getmtime(x), reverse=True)
+    latest_checkpoint = checkpoint_files[0]
+    return latest_checkpoint
+
+def main():
+    args = parse_args()
+    cfg = load_config(args.config)
+
+    # Create result directory
+    result_dir = os.path.join(cfg.project.result_dir, cfg.project.run_name)
+    os.makedirs(result_dir, exist_ok=True)
+    cfg.project.result_dir = result_dir  # Update result_dir in cfg
+
+    # Save config file in result directory
+    with open(os.path.join(result_dir, 'config.yaml'), 'w') as f:
+        yaml.dump(cfg.__dict__, f)
+
+    # Initialize the DataModule
+    if cfg.data.type == 'citywalk':
+        datamodule = CityWalkPretrainDataModule(cfg)
+    else:
+        raise ValueError(f"Invalid dataset type: {cfg.data.type}")
+
+    # Initialize the model
+    model = VJEPAModule(cfg)
+    print(pl.utilities.model_summary.ModelSummary(model, max_depth=2))
+
+    # Initialize logger
+    logger = None  # Default to no logger
+
+    # Check if logging with Wandb is enabled in config
+    use_wandb = cfg.logging.enable_wandb
+
+    if use_wandb:
+        try:
+            from pytorch_lightning.loggers import WandbLogger  # Import here to handle ImportError
+            wandb_logger = WandbLogger(
+                project=cfg.project.name,
+                name=cfg.project.run_name,
+                save_dir=result_dir
+            )
+            logger = wandb_logger
+            print("WandbLogger initialized.")
+        except ImportError:
+            print("Wandb is not installed. Skipping Wandb logging.")
+
+    checkpoint_callback = pl.callbacks.ModelCheckpoint(
+        dirpath=os.path.join(result_dir, 'checkpoints'),
+        save_last=True,
+        save_top_k=1,
+        monitor='val/loss_total',  # Monitor a metric the module actually logs
+    )
+
+    num_gpu = torch.cuda.device_count()
+    # num_gpu = 1
+
+    # Set up Trainer
+    if num_gpu > 1:
+        trainer = pl.Trainer(
+            default_root_dir=result_dir,
+            max_epochs=cfg.training.max_epochs,
+            logger=logger,  # Pass the logger (WandbLogger or None)
+            devices=num_gpu,
+            precision='16-mixed' if cfg.training.amp else 32,
+            accelerator='gpu',
+            callbacks=[
+                checkpoint_callback,
+                pl.callbacks.TQDMProgressBar(refresh_rate=cfg.logging.pbar_rate),
+            ],
+            log_every_n_steps=1,
+            strategy=DDPStrategy(find_unused_parameters=True)
+        )
+    else:
+        trainer = pl.Trainer(
+            default_root_dir=result_dir,
+            max_epochs=cfg.training.max_epochs,
+            logger=logger,  # Pass the logger (WandbLogger or None)
+            devices=num_gpu,
+            precision='16-mixed' if cfg.training.amp else 32,
+            accelerator='gpu',
+            callbacks=[
+                checkpoint_callback,
+                pl.callbacks.TQDMProgressBar(refresh_rate=cfg.logging.pbar_rate),
+            ],
+            log_every_n_steps=1,
+        )
+
+    if cfg.training.resume:
+        # Determine the checkpoint path
+        try:
+            if args.checkpoint:
+                checkpoint_path = args.checkpoint
+                if not os.path.isfile(checkpoint_path):
+                    raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
+            else:
+                # Automatically resume from the last checkpoint of this run
+                checkpoint_dir = os.path.join(cfg.project.result_dir, 'checkpoints')
+                if not os.path.isdir(checkpoint_dir):
+                    raise FileNotFoundError(f"Checkpoint directory does not exist: {checkpoint_dir}")
+                checkpoint_path = os.path.join(checkpoint_dir, 'last.ckpt')
+                if not os.path.isfile(checkpoint_path):
+                    raise FileNotFoundError(f"No 'last.ckpt' found in {checkpoint_dir}")
+                print(f"No checkpoint specified. Using the latest checkpoint: {checkpoint_path}")
+            print(f"Resuming training from checkpoint: {checkpoint_path}")
+        except FileNotFoundError:
+            print("No checkpoint found. Training from scratch.")
+            checkpoint_path = None
+        trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint_path)
+    else:
+        # Start training
+        trainer.fit(model, datamodule=datamodule)
+
+if __name__ == '__main__':
+    main()
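
Usage (sketch): assuming the repository root as the working directory and the dependencies used above (pytorch_lightning, torchvision, decord, efficientnet_pytorch, pyyaml) are installed, the toy pretraining run added here is launched with

    python pretrain.py --config config/pretrain/toy.yaml

An explicit checkpoint can be passed with --checkpoint results/toy_data_pretrain/checkpoints/last.ckpt (that path follows from result_dir, run_name, and the ModelCheckpoint settings above). With training.resume: true and no --checkpoint flag, the script falls back to last.ckpt in the run's checkpoints directory, or trains from scratch if none exists.

The tensor contract of VideoJEPA can be checked with a minimal sketch like the one below. It assumes the config keys resolve where the model code looks for them (encoder_feat_dim and ema_range at the model level, not under obs_encoder) and that the pinned torchvision still accepts pretrained=False:

    import torch
    from pretrain import load_config
    from model.vjepa import VideoJEPA

    cfg = load_config('config/pretrain/toy.yaml')
    model = VideoJEPA(cfg)

    # context_size and target_size are both 5 in the toy config
    inp = torch.randn(2, cfg.model.obs_encoder.context_size, 3, 224, 224)
    tgt = torch.randn(2, cfg.model.obs_encoder.target_size, 3, 224, 224)

    online, target = model(inp, tgt)
    print(online.shape, target.shape)  # expected: torch.Size([2, 5, 1024]) for both
    model.update_target_encoder(decay=0.998)  # EMA step applied once per training step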