add pretraining code
Gaaaavin committed Oct 25, 2024
1 parent 5a75cc9 commit 5735c8e
Showing 6 changed files with 649 additions and 0 deletions.
48 changes: 48 additions & 0 deletions config/pretrain/toy.yaml
@@ -0,0 +1,48 @@
project:
name: 'UrbanNav_Citywalk'
run_name: 'toy_data_pretrain'
result_dir: 'results'

training:
batch_size: 8
max_epochs: 10
gpus: 1
amp: true
normalize_step_length: true
resume: true
direction_loss_weight: 1.0
reg_coeff: 0.1

scheduler:
name: 'cosine'
step_size: 10
gamma: 0.1

optimizer:
name: 'adamw'
lr: 1e-4

model:
do_rgb_normalize: true
do_resize: true
obs_encoder:
type: 'vit_l_16'
context_size: 5
target_size: 5
encoder_feat_dim: 1024
ema_range: [0.998, 1]

data:
type: citywalk
video_dir: '/home/xinhao/citywalk/test_videos'
num_workers: 6
pose_fps: 5
video_fps: 30
target_fps: 1
num_train: 40
num_val: 20
num_test: 20

logging:
enable_wandb: false # Set to false to disable Wandb logging
pbar_rate: 1
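
Note: the modules added below read these settings through nested attribute access (e.g. cfg.model.obs_encoder.context_size, cfg.data.video_fps). A minimal sketch of how this file could be loaded, assuming an OmegaConf-style loader; the commit's actual entry script is not reproduced on this page:

# Illustrative only: OmegaConf is an assumption, not necessarily the repo's config loader.
from omegaconf import OmegaConf

cfg = OmegaConf.load('config/pretrain/toy.yaml')
print(cfg.model.obs_encoder.type)                 # 'vit_l_16'
print(cfg.data.video_fps // cfg.data.target_fps)  # 30 -> the dataset's frame_multiplier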
129 changes: 129 additions & 0 deletions data/citywalk_pretrain_dataset.py
@@ -0,0 +1,129 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from decord import VideoReader, cpu
import torch.nn.functional as F
from tqdm import tqdm

class VideoJEPADataset(Dataset):
def __init__(self, cfg, mode):
super().__init__()
self.cfg = cfg
self.mode = mode
self.video_dir = cfg.data.video_dir
self.context_size = cfg.model.obs_encoder.context_size # Number of input frames
self.target_size = cfg.model.obs_encoder.target_size # Number of target frames
self.video_fps = cfg.data.video_fps
self.target_fps = cfg.data.target_fps
self.frame_multiplier = self.video_fps // self.target_fps

# Load video paths
self.video_paths = [
os.path.join(self.video_dir, f)
for f in sorted(os.listdir(self.video_dir))
if f.endswith('.mp4')
]
print(f"Total videos found: {len(self.video_paths)}")

# Split videos according to mode
if mode == 'train':
self.video_paths = self.video_paths[:cfg.data.num_train]
elif mode == 'val':
self.video_paths = self.video_paths[cfg.data.num_train: cfg.data.num_train + cfg.data.num_val]
elif mode == 'test':
self.video_paths = self.video_paths[cfg.data.num_train + cfg.data.num_val:
cfg.data.num_train + cfg.data.num_val + cfg.data.num_test]
else:
raise ValueError(f"Invalid mode {mode}")

print(f"Number of videos for {mode}: {len(self.video_paths)}")

# Build the look-up table (lut) and video_ranges
self.lut = []
self.video_ranges = []
idx_counter = 0
for video_idx, video_path in enumerate(tqdm(self.video_paths, desc="Building LUT")):
# Initialize VideoReader to get number of frames
vr = VideoReader(video_path, ctx=cpu(0))
num_frames = len(vr)
usable_frames = num_frames // self.frame_multiplier - (self.context_size + self.target_size)
if usable_frames <= 0:
continue # Skip videos that are too short
start_idx = idx_counter
for frame_start in range(0, usable_frames, self.context_size):
self.lut.append((video_idx, frame_start))
idx_counter += 1
end_idx = idx_counter
self.video_ranges.append((start_idx, end_idx))
assert len(self.lut) > 0, "No usable samples found."

print(f"Total samples in LUT: {len(self.lut)}")
print(f"Total video ranges: {len(self.video_ranges)}")

# Initialize the video reader cache per worker
self.video_reader_cache = {'video_idx': None, 'video_reader': None}

def __len__(self):
return len(self.lut)

def __getitem__(self, index):
video_idx, frame_start = self.lut[index]

# Retrieve or create the VideoReader for the current video
if self.video_reader_cache['video_idx'] != video_idx:
# Replace the old VideoReader with the new one
self.video_reader_cache['video_reader'] = VideoReader(self.video_paths[video_idx], ctx=cpu(0))
self.video_reader_cache['video_idx'] = video_idx
video_reader = self.video_reader_cache['video_reader']

# Compute actual frame indices for input and target
actual_frame_start = frame_start * self.frame_multiplier
frame_indices_input = actual_frame_start + np.arange(self.context_size) * self.frame_multiplier
frame_indices_target = actual_frame_start + (self.context_size + np.arange(self.target_size)) * self.frame_multiplier

# Ensure frame indices are within the video length
num_frames = len(video_reader)
frame_indices_input = [min(idx, num_frames - 1) for idx in frame_indices_input]
frame_indices_target = [min(idx, num_frames - 1) for idx in frame_indices_target]

# Load the frames
input_frames = video_reader.get_batch(frame_indices_input).asnumpy()
target_frames = video_reader.get_batch(frame_indices_target).asnumpy()

# Process frames
input_frames = self.process_frames(input_frames)
target_frames = self.process_frames(target_frames)

sample = {
'input_frames': input_frames, # Shape: (context_size, 3, H, W)
'target_frames': target_frames # Shape: (target_size, 3, H, W)
}

return sample

def process_frames(self, frames):
"""
Convert frames to tensor, normalize, and resize if necessary.
Args:
frames (numpy.ndarray): Array of frames with shape (N, H, W, C).
Returns:
torch.Tensor: Processed frames with shape (N, 3, H, W).
"""
# Convert frames to tensor and normalize
frames = torch.tensor(frames).permute(0, 3, 1, 2).float() / 255.0 # Shape: (N, 3, H, W)

# # Optional resizing
# if self.cfg.data.do_resize:
# frames = F.interpolate(frames, size=(self.cfg.data.resize_height, self.cfg.data.resize_width),
# mode='bilinear', align_corners=False)

# # Optional normalization
# if self.cfg.data.do_rgb_normalize:
# mean = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1).to(frames.device)
# std = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1).to(frames.device)
# frames = (frames - mean) / std

return frames
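
For concreteness on the indexing above: with video_fps=30 and target_fps=1, frame_multiplier is 30, so a 3600-frame clip has 120 frames at the target rate, 120 - (context_size + target_size) = 110 usable start positions, and stepping by context_size=5 yields 22 LUT entries for that video. A minimal usage sketch, assuming cfg has been loaded as illustrated above and video_dir points at a folder of .mp4 files:

# Sketch: instantiate the dataset and check sample shapes (paths come from the toy config).
from torch.utils.data import DataLoader
from data.citywalk_pretrain_dataset import VideoJEPADataset

train_set = VideoJEPADataset(cfg, mode='train')
sample = train_set[0]
print(sample['input_frames'].shape)   # torch.Size([5, 3, H, W])  -> context_size frames
print(sample['target_frames'].shape)  # torch.Size([5, 3, H, W])  -> target_size frames

# The VideoReader cache is per dataset instance, so reads are cheapest when consecutive
# indices come from the same video (cf. CityWalkSampler used in the datamodule below).
loader = DataLoader(train_set, batch_size=2, num_workers=0)
batch = next(iter(loader))
print(batch['input_frames'].shape)    # torch.Size([2, 5, 3, H, W])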
124 changes: 124 additions & 0 deletions model/vjepa.py
@@ -0,0 +1,124 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from torchvision import models
from efficientnet_pytorch import EfficientNet

class ImageEncoder(nn.Module):
def __init__(self, cfg):
super().__init__()
self.obs_encoder_type = cfg.model.obs_encoder.type
self.do_rgb_normalize = cfg.model.do_rgb_normalize
self.do_resize = cfg.model.do_resize
self.encoder_feat_dim = cfg.model.encoder_feat_dim

if self.do_rgb_normalize:
self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

# Build the observation encoder
if self.obs_encoder_type.startswith("efficientnet"):
self.obs_encoder = EfficientNet.from_name(self.obs_encoder_type, in_channels=3)
self.num_obs_features = self.obs_encoder._fc.in_features
self.obs_encoder._fc = nn.Identity()
elif self.obs_encoder_type.startswith("resnet"):
model_constructor = getattr(models, self.obs_encoder_type)
self.obs_encoder = model_constructor(pretrained=False)
self.num_obs_features = self.obs_encoder.fc.in_features
self.obs_encoder.fc = nn.Identity()
elif self.obs_encoder_type.startswith("vit"):
model_constructor = getattr(models, self.obs_encoder_type)
self.obs_encoder = model_constructor(pretrained=False)
self.num_obs_features = self.obs_encoder.hidden_dim
self.obs_encoder.heads = nn.Identity()
else:
raise NotImplementedError(f"Observation encoder type {self.obs_encoder_type} not implemented")

# Compress observation encodings to encoder_feat_dim
if self.num_obs_features != self.encoder_feat_dim:
self.compress_obs_enc = nn.Linear(self.num_obs_features, self.encoder_feat_dim)
else:
self.compress_obs_enc = nn.Identity()

def forward(self, obs):
"""
Args:
obs: (B*N, 3, H, W)
Returns:
embeddings: (B*N, encoder_feat_dim)
"""
# Pre-processing
if self.do_rgb_normalize:
obs = (obs - self.mean) / self.std
if self.do_resize:
obs = F.interpolate(obs, size=(224, 224), mode='bilinear', align_corners=False)

# Observation Encoding
if self.obs_encoder_type.startswith("efficientnet"):
obs_enc = self.obs_encoder.extract_features(obs)
obs_enc = self.obs_encoder._avg_pooling(obs_enc)
obs_enc = obs_enc.flatten(start_dim=1)
elif self.obs_encoder_type.startswith("resnet"):
x = self.obs_encoder.conv1(obs)
x = self.obs_encoder.bn1(x)
x = self.obs_encoder.relu(x)
x = self.obs_encoder.maxpool(x)
x = self.obs_encoder.layer1(x)
x = self.obs_encoder.layer2(x)
x = self.obs_encoder.layer3(x)
x = self.obs_encoder.layer4(x)
x = self.obs_encoder.avgpool(x)
obs_enc = torch.flatten(x, 1)
elif self.obs_encoder_type.startswith("vit"):
obs_enc = self.obs_encoder(obs) # Returns class token embedding
else:
raise NotImplementedError(f"Observation encoder type {self.obs_encoder_type} not implemented")

# Compress embeddings
obs_enc = self.compress_obs_enc(obs_enc)

return obs_enc

class VideoJEPA(nn.Module):
def __init__(self, cfg):
super().__init__()

# Initialize the online encoder
self.online_encoder = ImageEncoder(cfg)

# Initialize the target encoder as a copy of the online encoder
self.target_encoder = copy.deepcopy(self.online_encoder)
# Freeze the target encoder parameters (they will be updated via EMA)
for param in self.target_encoder.parameters():
param.requires_grad = False

@torch.no_grad()
def update_target_encoder(self, decay=0.998):
# EMA update for the target encoder
for param_o, param_t in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
param_t.data.mul_(decay).add_(param_o.data * (1 - decay))

def forward(self, input_obs, target_obs):
"""
Args:
input_obs: (B, N_in, 3, H, W) tensor of past frames
target_obs: (B, N_target, 3, H, W) tensor of future frames
Returns:
online_embeddings: (B, N_in, encoder_feat_dim)
target_embeddings: (B, N_target, encoder_feat_dim)
"""
# Process input observations through online encoder
B, N_in, C, H, W = input_obs.shape
input_obs_flat = input_obs.view(B * N_in, C, H, W)
online_embeddings_flat = self.online_encoder(input_obs_flat)
online_embeddings = online_embeddings_flat.view(B, N_in, -1)

# Process target observations through target encoder with stop gradient
with torch.no_grad():
B, N_target, C, H, W = target_obs.shape
target_obs_flat = target_obs.view(B * N_target, C, H, W)
target_embeddings_flat = self.target_encoder(target_obs_flat)
target_embeddings = target_embeddings_flat.view(B, N_target, -1)

return online_embeddings, target_embeddings
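
The predictor and loss used for pretraining live in the Lightning module, which is part of this commit but not reproduced on this page. Below is a hypothetical training step under stated assumptions (placeholder linear predictor, smooth L1 loss, EMA decay taken from the lower end of ema_range); it is not the repository's actual objective:

# Illustrative step only; predictor, loss, and decay value here are placeholders.
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.vjepa import VideoJEPA

model = VideoJEPA(cfg)                        # cfg as loaded from config/pretrain/toy.yaml
predictor = nn.Linear(1024, 1024)             # 1024 = encoder_feat_dim in the toy config
optimizer = torch.optim.AdamW(
    list(model.online_encoder.parameters()) + list(predictor.parameters()),
    lr=1e-4,                                  # optimizer.lr in the toy config
)

input_frames = torch.randn(2, 5, 3, 224, 224)     # (B, context_size, 3, H, W)
target_frames = torch.randn(2, 5, 3, 224, 224)    # (B, target_size, 3, H, W)

online_emb, target_emb = model(input_frames, target_frames)  # target_emb carries no grad
pred = predictor(online_emb)                                  # predict target embeddings
loss = F.smooth_l1_loss(pred, target_emb)
loss.backward()
optimizer.step()
optimizer.zero_grad()
model.update_target_encoder(decay=0.998)      # EMA step after the optimizer step; 0.998 = ema_range[0]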
33 changes: 33 additions & 0 deletions pl_modules/citywalk_pretrain_datamodule.py
@@ -0,0 +1,33 @@
# pl_modules/citywalk_pretrain_datamodule.py

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from data.citywalk_dataset import CityWalkSampler
from data.citywalk_pretrain_dataset import VideoJEPADataset

class CityWalkPretrainDataModule(pl.LightningDataModule):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.batch_size = cfg.training.batch_size
self.num_workers = cfg.data.num_workers

def setup(self, stage=None):
if stage == 'fit' or stage is None:
self.train_dataset = VideoJEPADataset(self.cfg, mode='train')
self.val_dataset = VideoJEPADataset(self.cfg, mode='val')

if stage == 'test' or stage is None:
self.test_dataset = VideoJEPADataset(self.cfg, mode='test')

def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size,
num_workers=self.num_workers, sampler=CityWalkSampler(self.train_dataset))

def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size,
num_workers=self.num_workers, shuffle=False)

def test_dataloader(self):
return DataLoader(self.test_dataset, batch_size=self.batch_size,
num_workers=self.num_workers, shuffle=False)
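
The pretraining LightningModule and entry script that consume this datamodule are among the commit's six files but are not shown above. A hypothetical wiring, assuming a pytorch_lightning 2.x Trainer API and a placeholder module name:

# Placeholder wiring; `VJEPAPretrainModule` is a hypothetical name, not the repo's class.
import pytorch_lightning as pl
from pl_modules.citywalk_pretrain_datamodule import CityWalkPretrainDataModule

datamodule = CityWalkPretrainDataModule(cfg)   # cfg from config/pretrain/toy.yaml
trainer = pl.Trainer(
    max_epochs=10,                             # training.max_epochs
    devices=1,                                 # training.gpus
    precision='16-mixed',                      # training.amp: true
)
# trainer.fit(VJEPAPretrainModule(cfg), datamodule=datamodule)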